# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import operator
import pytest
import torch
import numpy as np
from torch import nn
from torch.nn import Module
from torch.export import export

import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend.torch import from_exported_program


def verify_model(
    torch_model, example_args, binding, expected, dynamic_shapes=None, run_ep_decomposition=True
):
    """Export ``torch_model``, convert it to Relax, and compare against ``expected``.

    The torch model is traced via ``torch.export.export`` and translated with
    ``from_exported_program``.  Parameters in ``binding`` (name -> numpy array)
    are bound into ``expected`` before the structural-equality check.
    """
    exported = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes)
    actual = from_exported_program(exported, run_ep_decomposition=run_ep_decomposition)

    params = {name: tvm.runtime.tensor(value) for name, value in binding.items()}
    reference = relax.transform.BindParams("main", params)(expected)
    tvm.ir.assert_structural_equal(actual, reference)


# (torch op, corresponding Relax op) pairs for elementwise unary operators
# whose output dtype matches the input dtype; parametrizes test_basic_unary_ops.
operator_basic_unary = [
    (torch.abs, R.abs),
    (torch.acos, R.acos),
    (torch.acosh, R.acosh),
    (torch.asin, R.asin),
    (torch.asinh, R.asinh),
    (torch.atan, R.atan),
    (torch.atanh, R.atanh),
    (torch.bitwise_not, R.bitwise_not),
    (torch.ceil, R.ceil),
    (torch.cos, R.cos),
    (torch.cosh, R.cosh),
    (torch.erf, R.erf),
    (torch.exp, R.exp),
    (torch.floor, R.floor),
    (torch.ops.aten.gelu, R.nn.gelu),
    (torch.log, R.log),
    (torch.neg, R.negative),
    (torch.relu, R.nn.relu),
    (torch.round, R.round),
    (torch.rsqrt, R.rsqrt),
    (torch.sigmoid, R.sigmoid),
    (torch.sin, R.sin),
    (torch.sinh, R.sinh),
    (torch.sign, R.sign),
    (torch.sqrt, R.sqrt),
    (torch.tan, R.tan),
    (torch.tanh, R.tanh),
    (torch.trunc, R.trunc),
]


@pytest.mark.parametrize("pytorch_op, relax_op", operator_basic_unary)
def test_basic_unary_ops(pytorch_op, relax_op):
    """Each dtype-preserving elementwise torch op lowers to its Relax counterpart."""
    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    class UnaryOp(Module):
        def forward(self, input):
            return pytorch_op(input)

    # Expected Relax IR: a single call of the mapped op inside a dataflow block.
    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(input_1)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(UnaryOp(), example_args, {}, expected)


# (torch op, Relax op) pairs for elementwise unary predicates that return a
# boolean tensor; parametrizes test_bool_unary_ops.
operator_bool_unary = [
    (torch.isinf, R.isinf),
    (torch.isnan, R.isnan),
]


@pytest.mark.parametrize("pytorch_op, relax_op", operator_bool_unary)
def test_bool_unary_ops(pytorch_op, relax_op):
    """Elementwise predicate ops lower to Relax ops producing a bool tensor."""
    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    class UnaryOp(Module):
        def forward(self, input):
            return pytorch_op(input)

    # Expected Relax IR: float32 input, bool output of the same shape.
    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv,)
                R.output(gv)
            return gv

    verify_model(UnaryOp(), example_args, {}, expected)


def test_sqrt_integer_input():
    """Test that sqrt operation works with integer tensors by auto-converting to float.

    Both int64 and int32 inputs are expected to lower to an ``astype`` to
    float32 followed by ``R.sqrt``.
    """
    example_args = (torch.tensor([[4, 9, 16, 25]], dtype=torch.int64),)

    class SqrtIntModel(Module):
        def forward(self, input):
            return torch.sqrt(input)

    @tvm.script.ir_module
    class expected_int64:
        @R.function
        def main(
            input_1: R.Tensor((1, 4), dtype="int64")
        ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32")
                lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv)
                gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(SqrtIntModel(), example_args, {}, expected_int64)

    # Same model, int32 input: conversion to float32 is still expected.
    example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),)

    @tvm.script.ir_module
    class expected_int32:
        @R.function
        def main(
            input_1: R.Tensor((1, 3), dtype="int32")
        ) -> R.Tuple(R.Tensor((1, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3), dtype="float32") = R.astype(input_1, dtype="float32")
                lv1: R.Tensor((1, 3), dtype="float32") = R.sqrt(lv)
                gv: R.Tuple(R.Tensor((1, 3), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32)


def test_extended_unary_ops():
    """Check lowering of composite/decomposed unary ops to their Relax form.

    Each section below pairs one or more torch modules with the exact Relax IR
    that the exported-program frontend is expected to emit after aten
    decomposition.  In-place variants (``*_`` ops) return the result twice in
    the output tuple, matching the exported graph.
    """
    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    # celu
    class Celu1(Module):
        def __init__(self):
            super().__init__()
            self.celu = torch.nn.CELU()

        def forward(self, input):
            return self.celu(input)

    class Celu2(Module):
        def forward(self, input):
            return torch.nn.functional.celu(input)

    # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
    @tvm.script.ir_module
    class expected_celu:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
                    lv, R.const(1.0, "float32")
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
                    input_1, R.const(0.0, "float32")
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv2, input_1, lv1)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    verify_model(Celu1(), example_args, {}, expected_celu)
    verify_model(Celu2(), example_args, {}, expected_celu)

    # clamp
    class Clamp(Module):
        def forward(self, input):
            return torch.clamp(input, min=0.1, max=0.5)

    @tvm.script.ir_module
    class expected_clamp:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    input,
                    R.prim_value(T.float64(0.10000000000000001)),
                    R.prim_value(T.float64(0.5)),
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Clamp(), example_args, {}, expected_clamp)

    # clamp with only a lower bound: the missing max becomes +inf.
    class ClampMinOnly(Module):
        def forward(self, input):
            return torch.clamp(input, min=0.5, max=None)

    @tvm.script.ir_module
    class expected_clamp_min_only:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    input, R.prim_value(T.float64(0.5)), R.prim_value(T.float64("inf"))
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)

    # clamp with tensor bounds: lowers to broadcast + maximum/minimum.
    class ClampTensors(Module):
        def forward(self, input):
            return torch.clamp(input, min=input, max=input)

    @tvm.script.ir_module
    class expected_clamp_tensors:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
                    input, R.shape([1, 3, 10, 10])
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(input, lv)
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
                    input, R.shape([1, 3, 10, 10])
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2)
                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    lv3, R.prim_value(T.float64("-inf")), R.prim_value(T.float64("inf"))
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,)
                R.output(gv)
            return gv

    verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)

    # dropout

    class Dropout1(Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, input):
            return self.dropout(input)

    class Dropout2(Module):
        def forward(self, input):
            return torch.dropout(input, 0.5, train=True)

    class Dropout3(Module):
        def forward(self, input):
            return torch.ops.aten.dropout_(input, 0.5, train=True)

    # Dropout1/Dropout2 lower to a pass-through of the input.
    @tvm.script.ir_module
    class expected_dropout_for_1_2:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_dropout_for_3:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
        ):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros(
                    R.shape([1, 3, 10, 10]), dtype="float32"
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv, R.const(0.5, "float32")
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv1)
                gv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                ) = (lv2, lv2)
                R.output(gv)
            return gv

    verify_model(Dropout1(), example_args, {}, expected_dropout_for_1_2)
    verify_model(Dropout2(), example_args, {}, expected_dropout_for_1_2)
    verify_model(Dropout3(), example_args, {}, expected_dropout_for_3)

    # elu
    class Elu(Module):
        def __init__(self):
            super().__init__()
            self.elu = torch.nn.ELU()

        def forward(self, input):
            return self.elu(input)

    class Elu2(Module):
        def forward(self, input):
            return torch.nn.functional.elu(input)

    @tvm.script.ir_module
    class expected_elu:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
                    input, R.const(0.0, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    input, R.const(1.0, "float32")
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    input, R.const(1.0, "float32")
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv2)
                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
                    lv3, R.const(1.0, "float32")
                )
                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    lv4, R.const(1.0, "float32")
                )
                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, lv1, lv5)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                R.output(gv)
            return gv

    verify_model(Elu(), example_args, {}, expected_elu)
    verify_model(Elu2(), example_args, {}, expected_elu)

    # hardsigmoid
    class Hardsigmoid(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.hs = torch.nn.Hardsigmoid()

        def forward(self, input):
            return self.hs(input)

    class Hardsigmoid2(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.hardsigmoid(input)

    # hardsigmoid decomposes to clip(x + 3, 0, 6) / 6.
    @tvm.script.ir_module
    class expected_hardsigmoid:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
                    inp_0, R.const(3.0, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    lv, R.prim_value(0), R.prim_value(T.float64("inf"))
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    lv1, R.prim_value(T.float64("-inf")), R.prim_value(6)
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv2, R.const(6.0, "float32")
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid)
    verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid)

    # hardswish
    class Hardswish(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.hs = torch.nn.Hardswish()

        def forward(self, input):
            return self.hs(input)

    class Hardswish2(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.hardswish(input)

    class Hardswish3(torch.nn.Module):
        def forward(self, input):
            return torch.ops.aten.hardswish_(input)

    # hardswish decomposes to x * clip(x + 3, 0, 6) / 6.
    @tvm.script.ir_module
    class expected_hardswish_for_1_2:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
                    inp_0, R.const(3.0, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    lv, R.prim_value(0), R.prim_value(T.float64("inf"))
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    lv1, R.prim_value(T.float64("-inf")), R.prim_value(6)
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2)
                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv3, R.const(6.0, "float32")
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_hardswish_for_3:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
        ):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
                    input, R.const(3.0, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    lv, R.prim_value(0), R.prim_value(T.float64("inf"))
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    lv1, R.prim_value(T.float64("-inf")), R.prim_value(6)
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv2)
                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv3, R.const(6.0, "float32")
                )
                gv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                ) = (lv4, lv4)
                R.output(gv)
            return gv

    verify_model(Hardswish(), example_args, {}, expected_hardswish_for_1_2)
    verify_model(Hardswish2(), example_args, {}, expected_hardswish_for_1_2)
    verify_model(Hardswish3(), example_args, {}, expected_hardswish_for_3)

    # isfinite
    class IsFinite(Module):
        def forward(self, input):
            return torch.isfinite(input)

    # isfinite decomposes to (x == x) & (|x| != inf).
    @tvm.script.ir_module
    class expected_isfinite:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input)
                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(
                    lv, R.const(float("inf"), "float32")
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input, input)
                lv3: R.Tensor((1, 3, 10, 10), dtype="bool") = R.multiply(lv2, lv1)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv3,)
                R.output(gv)
            return gv

    verify_model(IsFinite(), example_args, {}, expected_isfinite)

    # log2
    class Log2(Module):
        def forward(self, x):
            return torch.log2(x)

    # log2(x) = log(x) / ln(2)
    @tvm.script.ir_module
    class Expected_log2:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(inp_0)
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv, R.const(0.69314718246459961, "float32")
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(Log2(), example_args, {}, Expected_log2)

    # log10
    class Log10(Module):
        def forward(self, x):
            return torch.log10(x)

    # log10(x) = log(x) / ln(10)
    @tvm.script.ir_module
    class Expected_log10:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(inp_0)
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv, R.const(2.302585092994046, "float32")
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(Log10(), example_args, {}, Expected_log10)

    # log1p
    class Log1p(Module):
        def forward(self, x):
            return torch.log1p(x)

    @tvm.script.ir_module
    class Expected_log1p:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(
                    R.add(inp_0, R.const(1, "float32"))
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Log1p(), example_args, {}, Expected_log1p)

    # reciprocal
    class Reciprocal(Module):
        def forward(self, input):
            return torch.reciprocal(input)

    @tvm.script.ir_module
    class expected_reciprocal:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    R.const(1.0, "float32"), input_1
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Reciprocal(), example_args, {}, expected_reciprocal)

    # Returns the maximum value of all elements in the input tensor.
    class MaxModel(Module):
        def forward(self, input):
            return torch.max(input)

    @tvm.script.ir_module
    class expected_max:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.max(input, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(MaxModel(), example_args, {}, expected_max)

    # Returns the minimum value of all elements in the input tensor.
    class MinModel(Module):
        def forward(self, input):
            return torch.min(input)

    @tvm.script.ir_module
    class expected_min:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.min(input, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(MinModel(), example_args, {}, expected_min)

    # relu6
    class ReLU6_1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, x):
            return self.relu6(x)

    class ReLU6_2(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu6(x)

    class ReLU6_3(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.relu6_(x)

    @tvm.script.ir_module
    class expected_relu6_1:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    x, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(6.0))
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_relu6_2:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu6(x)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_relu6_3:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
        ):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    x, R.prim_value(0), R.prim_value(6)
                )
                gv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                ) = (lv, lv)
                R.output(gv)
            return gv

    verify_model(ReLU6_1(), example_args, {}, expected_relu6_1)
    verify_model(ReLU6_2(), example_args, {}, expected_relu6_2)
    verify_model(ReLU6_3(), example_args, {}, expected_relu6_3)

    # selu
    class SELU(Module):
        def forward(self, input):
            return torch.nn.functional.selu(input)

    @tvm.script.ir_module
    class expected_selu:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
                    input, R.const(0.0, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    input, R.const(1.0507010221481323, "float32")
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    input, R.const(1.0, "float32")
                )
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv2)
                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
                    lv3, R.const(1.0, "float32")
                )
                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    lv4, R.const(1.7580993175506592, "float32")
                )
                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, lv1, lv5)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                R.output(gv)
            return gv

    verify_model(SELU(), example_args, {}, expected_selu)

    # silu
    class SiLU(Module):
        def forward(self, input):
            return torch.nn.functional.silu(input)

    # silu(x) = x * sigmoid(x)
    @tvm.script.ir_module
    class expected_silu:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input)
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(SiLU(), example_args, {}, expected_silu)

    # silu_
    class SiLU_(Module):
        def forward(self, input):
            return torch.ops.aten.silu_(input)

    @tvm.script.ir_module
    class expected_silu_:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
        ):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input)
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv)
                gv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                ) = (
                    lv1,
                    lv1,
                )
                R.output(gv)
            return gv

    verify_model(SiLU_(), example_args, {}, expected_silu_)

    # square
    class Square(Module):
        def forward(self, input):
            return torch.square(input)

    @tvm.script.ir_module
    class expected_square:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(
                    input, R.const(2.0, "float32")
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Square(), example_args, {}, expected_square)

    # relu_
    class ReLU_(Module):
        def forward(self, input):
            return torch.relu_(input.clone())

    @tvm.script.ir_module
    class expected_relu_:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(ReLU_(), example_args, {}, expected_relu_)

def test_hardtanh():
    """hardtanh lowers to R.clip(-1, 1); the in-place variant returns it twice."""

    class Hardtanh(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.ht = torch.nn.Hardtanh()

        def forward(self, input):
            return self.ht(input)

    class Hardtanh2(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.hardtanh(input)

    class Hardtanh3(torch.nn.Module):
        def forward(self, input):
            return torch.ops.aten.hardtanh_(input)

    @tvm.script.ir_module
    class expected_for_1_2:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0))
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # In-place hardtanh_: the exported graph returns the result twice.
    @tvm.script.ir_module
    class expected_hardtanh_for_3:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
        ):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
                    inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0))
                )
                gv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                ) = (lv, lv)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Hardtanh(), example_args, {}, expected_for_1_2)
    verify_model(Hardtanh2(), example_args, {}, expected_for_1_2)
    verify_model(Hardtanh3(), example_args, {}, expected_hardtanh_for_3)


def test_softplus():
    """Softplus (module and functional forms) lowers to the decomposed
    mul/exp/add/log/divide/where pattern with beta=1.0, threshold=20.0.

    Note: the function-local ``import torch`` / ``from torch.nn import Module``
    duplicated the module-level imports and have been removed.
    """
    torch.set_grad_enabled(False)

    class Softplus0(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Positional args: beta=1.0, threshold=20.0.
            self.softplus = torch.nn.Softplus(1.0, 20.0)

        def forward(self, x):
            return self.softplus(x)

    class Softplus1(Module):
        def forward(self, input):
            return torch.nn.functional.softplus(input, 1.0, 20.0)

    # Expected decomposition: log(1 + exp(beta * x)) / beta, with the input
    # passed through unchanged where beta * x > threshold (lv5/lv6).
    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    x, R.const(1.0, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv)
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv1, R.const(1.0, "float32"))
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(lv2)
                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
                    lv3, R.const(1.0, "float32")
                )
                lv5: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
                    lv, R.const(20.0, "float32")
                )
                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv5, x, lv4)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Softplus0(), example_args, {}, expected)
    verify_model(Softplus1(), example_args, {}, expected)


def test_leakyrelu():
    """LeakyReLU (module, functional, and in-place aten forms) lowers to
    R.nn.leakyrelu; the in-place form additionally returns the aliased input.

    Note: the function-local ``import torch`` / ``from torch.nn import Module``
    duplicated the module-level imports and have been removed.
    """
    torch.set_grad_enabled(False)

    class LeakyReLU0(Module):
        def __init__(self):
            super().__init__()
            self.leakyrelu = torch.nn.LeakyReLU(0.02)

        def forward(self, input):
            return self.leakyrelu(input)

    class LeakyReLU1(Module):
        def forward(self, input):
            return torch.nn.functional.leaky_relu(input, 0.02)

    # In-place variant; exported graph returns the result twice (see below).
    class LeakyReLU2(Module):
        def forward(self, input):
            return torch.ops.aten.leaky_relu_(input, 0.02)

    @tvm.script.ir_module
    class expected_for_1_2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, alpha=0.02)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_for_3:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
        ):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input, alpha=0.02)
                gv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                ) = (lv, lv)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(LeakyReLU0(), example_args, {}, expected_for_1_2)
    verify_model(LeakyReLU1(), example_args, {}, expected_for_1_2)
    verify_model(LeakyReLU2(), example_args, {}, expected_for_3)


def test_logaddexp():
    """torch.logaddexp lowers to the numerically-stable decomposition:
    max(a, b) + log1p-style log(1 + exp(min - max)), with NaN/inf handling
    via the equal/not_equal/where chain below."""

    class LogAddExp(Module):
        def forward(self, input1, input2):
            return torch.logaddexp(input1, input2)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            input1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            input2: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                # lv1 = elementwise max, lv2 = elementwise min of the inputs.
                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(input1, input2)
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, input1, input2)
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, input2, input1)
                # lv4..lv9: detect non-finite inputs where both operands compare equal.
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input1)
                lv4: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(
                    lv3, R.const(float("inf"), "float32")
                )
                lv5: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input1, input1)
                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.multiply(lv5, lv4)
                lv7: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_not(lv6)
                lv8: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input1, input2)
                lv9: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_and(lv7, lv8)
                # lv10..lv14: max + log(1 + exp(min - max)).
                lv10: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lv2, lv1)
                lv11: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv10)
                lv12: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
                    lv11, R.const(1.0, "float32")
                )
                lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(lv12)
                lv14: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv1, lv13)
                lv15: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv9, input1, lv14)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv15,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(1, 3, 10, 10, dtype=torch.float32),
        torch.randn(1, 3, 10, 10, dtype=torch.float32),
    )
    verify_model(LogAddExp(), example_args, {}, expected)


def test_logsoftmax():
    """LogSoftmax (module and functional forms) maps to R.nn.log_softmax."""

    class LogSoftmax(Module):
        def __init__(self):
            super().__init__()
            self.lsm = torch.nn.LogSoftmax(dim=1)

        def forward(self, input):
            return self.lsm(input)

    class LogSoftmax2(Module):
        def forward(self, input):
            return torch.nn.functional.log_softmax(input, dim=1)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                # torch dim=1 becomes Relax axis=1.
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.log_softmax(input_1, axis=1)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(LogSoftmax(), example_args, {}, expected1)
    verify_model(LogSoftmax2(), example_args, {}, expected1)


def test_prelu():
    """PReLU (module and functional forms) lowers to the reshape/greater/
    multiply/where decomposition; the learned alpha is a constant here."""

    class Prelu1(Module):
        def __init__(self, num_parameters=1, alpha=0.25):
            super().__init__()
            self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha)

        def forward(self, x):
            return self.prelu(x)

    class Prelu2(torch.nn.Module):
        def __init__(self):
            super(Prelu2, self).__init__()
            # Same initial alpha as Prelu1 so both share one expected module.
            self.alpha = torch.nn.Parameter(torch.tensor([0.25]))

        def forward(self, x):
            return torch.nn.functional.prelu(x, self.alpha)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                # alpha broadcast-reshaped to rank 4; where(x > 0, x, alpha * x).
                lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.reshape(
                    R.const([0.25], dtype="float32"), R.shape([1, 1, 1, 1])
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(x, R.const(0.0, "float32"))
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, x)
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, x, lv2)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Prelu1(), example_args, {}, expected)
    verify_model(Prelu2(), example_args, {}, expected)


def test_softmax():
    """Softmax (module and functional forms) maps to R.nn.softmax."""

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.sm = torch.nn.Softmax(dim=1)

        def forward(self, input):
            return self.sm(input)

    class Softmax2(Module):
        def forward(self, input):
            return torch.nn.functional.softmax(input, dim=1)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                # torch dim=1 becomes Relax axis=1.
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Softmax(), example_args, {}, expected1)
    verify_model(Softmax2(), example_args, {}, expected1)


def test_softsign():
    """Softsign (module and functional forms) lowers to x / (1 + |x|)."""

    class Softsign(Module):
        def __init__(self):
            super().__init__()
            self.ss = torch.nn.Softsign()

        def forward(self, input):
            return self.ss(input)

    class Softsign2(Module):
        def forward(self, input):
            return torch.nn.functional.softsign(input)

    @tvm.script.ir_module
    class expected_softsign:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                # softsign(x) = x / (1 + |x|)
                abs_val = R.abs(input)
                denom = R.add(abs_val, R.const(1.0, "float32"))
                result = R.divide(input, denom)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (result,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Softsign(), example_args, {}, expected_softsign)
    verify_model(Softsign2(), example_args, {}, expected_softsign)


def test_softshrink():
    """Softshrink (module and functional forms, lambd=0.5) lowers to the
    abs/greater/sign/subtract/where decomposition."""

    class Softshrink(Module):
        def __init__(self):
            super().__init__()
            self.softshrink = torch.nn.Softshrink(lambd=0.5)

        def forward(self, input):
            return self.softshrink(input)

    class Softshrink2(Module):
        def forward(self, input):
            return torch.nn.functional.softshrink(input, lambd=0.5)

    @tvm.script.ir_module
    class expected_softshrink:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                # where(|x| > lambd, x - sign(x) * lambd, 0)
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input)
                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lv, R.const(0.5, "float32"))
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sign(input)
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    lv2, R.const(0.5, "float32")
                )
                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(input, lv3)
                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
                    input, R.const(0.0, "float32")
                )
                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, lv4, lv5)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Softshrink(), example_args, {}, expected_softshrink)
    verify_model(Softshrink2(), example_args, {}, expected_softshrink)


def test_tril_triu():
    """torch.tril/torch.triu with diagonal offset 1 lower to an arange-based
    row/column index mask combined with R.where; the only difference between
    the two expected modules is less_equal vs. greater_equal."""
    example_args = (torch.randn(10, 10, dtype=torch.float32),)

    class Tril(Module):
        def forward(self, input):
            return torch.tril(input, 1)

    @tvm.script.ir_module
    class expected_tril:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                # lv4[i, j] = j - i; keep entries where j - i <= diagonal (1).
                lv: R.Tensor((10,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
                lv2: R.Tensor((10,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
                )
                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
                lv5: R.Tensor((10, 10), dtype="bool") = R.less_equal(lv4, R.const(1, "int64"))
                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                R.output(gv)
            return gv

    verify_model(Tril(), example_args, {}, expected_tril)

    class Triu(Module):
        def forward(self, input):
            return torch.triu(input, 1)

    @tvm.script.ir_module
    class expected_triu:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                # Same index mask, but keep entries where j - i >= diagonal (1).
                lv: R.Tensor((10,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
                lv2: R.Tensor((10,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
                )
                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
                lv5: R.Tensor((10, 10), dtype="bool") = R.greater_equal(lv4, R.const(1, "int64"))
                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                R.output(gv)
            return gv

    verify_model(Triu(), example_args, {}, expected_triu)


# (torch/operator callable, corresponding Relax op) pairs for elementwise
# binary ops whose result dtype matches the operand dtype. Names ending in
# "_" are in-place aten ops; test_binary1 expects those to also return the
# aliased input (two-element output tuple).
operator_binary_1 = [
    (operator.add, R.add),
    (torch.ops.aten.add_, R.add),
    (torch.ops.aten.bitwise_or, R.bitwise_or),
    (torch.ops.aten.bitwise_or_, R.bitwise_or),
    (operator.sub, R.subtract),
    (operator.mul, R.multiply),
    (torch.ops.aten.mul_, R.multiply),
    (operator.truediv, R.divide),
    (operator.floordiv, R.floor_divide),
    (torch.ops.aten.fmod, R.mod),
    (operator.pow, R.power),
    (operator.mod, R.floor_mod),
    (operator.and_, R.bitwise_and),
    (operator.or_, R.bitwise_or),
    (operator.xor, R.bitwise_xor),
]


@pytest.mark.parametrize("op, relax_op", operator_binary_1)
def test_binary1(op, relax_op):
    """Each dtype-preserving binary op is checked in tensor-tensor and
    tensor-scalar form; in-place aten ops get the two-output expected IR."""
    example_args1 = (
        torch.randn(10, 10, dtype=torch.float32),
        torch.randn(10, 10, dtype=torch.float32),
    )
    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)

    # Tensor-tensor form.
    class Binary1(Module):
        def __init__(self, op):
            super().__init__()
            self.op = op

        def forward(self, lhs, rhs):
            return self.op(lhs, rhs)

    @tvm.script.ir_module
    class expected_binary1:
        @R.function
        def main(
            lhs: R.Tensor((10, 10), dtype="float32"),
            rhs: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # In-place ops also return the mutated input, hence the duplicated output.
    @tvm.script.ir_module
    class expected_binary1_inplace:
        @R.function
        def main(
            lhs: R.Tensor((10, 10), dtype="float32"),
            rhs: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs)
                gv: R.Tuple(
                    R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")
                ) = (lv, lv)
                R.output(gv)
            return gv

    # Tensor-scalar form; the Python float becomes a Relax constant.
    class Binary2(Module):
        def __init__(self, op):
            super().__init__()
            self.op = op

        def forward(self, lhs):
            return self.op(lhs, 1.0)

    @tvm.script.ir_module
    class expected_binary2:
        @R.function
        def main(
            lhs: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, R.const(1.0))
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_binary2_inplace:
        @R.function
        def main(
            lhs: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, R.const(1.0))
                gv: R.Tuple(
                    R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")
                ) = (lv, lv)
                R.output(gv)
            return gv

    inplace_ops = [
        torch.ops.aten.add_,
        torch.ops.aten.bitwise_or_,
        torch.ops.aten.mul_,
    ]

    expected1 = expected_binary1_inplace if op in inplace_ops else expected_binary1
    expected2 = expected_binary2_inplace if op in inplace_ops else expected_binary2
    verify_model(Binary1(op), example_args1, {}, expected1)
    verify_model(Binary2(op), example_args2, {}, expected2)


# (Python comparison operator, corresponding Relax op) pairs; these produce a
# bool-dtype result, which is why test_binary2 has its own expected modules.
operator_binary_2 = [
    (operator.eq, R.equal),
    (operator.ne, R.not_equal),
    (operator.lt, R.less),
    (operator.le, R.less_equal),
    (operator.gt, R.greater),
    (operator.ge, R.greater_equal),
]


@pytest.mark.parametrize("op, relax_op", operator_binary_2)
def test_binary2(op, relax_op):
    """Each comparison op is checked in tensor-tensor and tensor-scalar form;
    the result dtype is bool."""
    example_args1 = (
        torch.randn(10, 10, dtype=torch.float32),
        torch.randn(10, 10, dtype=torch.float32),
    )
    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)

    # Tensor-tensor form.
    class Binary1(Module):
        def __init__(self, op):
            super().__init__()
            self.op = op

        def forward(self, lhs, rhs):
            return self.op(lhs, rhs)

    @tvm.script.ir_module
    class expected_binary1:
        @R.function
        def main(
            lhs: R.Tensor((10, 10), dtype="float32"),
            rhs: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, rhs)
                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
                R.output(gv)
            return gv

    # Tensor-scalar form; the Python float becomes a Relax constant.
    class Binary2(Module):
        def __init__(self, op):
            super().__init__()
            self.op = op

        def forward(self, lhs):
            return self.op(lhs, 1.0)

    @tvm.script.ir_module
    class expected_binary2:
        @R.function
        def main(
            lhs: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, R.const(1.0))
                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Binary1(op), example_args1, {}, expected_binary1)
    verify_model(Binary2(op), example_args2, {}, expected_binary2)


def test_binary3():
    """Binary ops with dedicated Relax mappings: torch.max -> R.maximum,
    torch.min -> R.minimum, torch.rsub -> R.subtract with swapped operands."""
    example_args1 = (
        torch.randn(10, 10, dtype=torch.float32),
        torch.randn(10, 10, dtype=torch.float32),
    )
    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)

    # Max
    class Max1(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    @I.ir_module
    class expected_max1:
        @R.function
        def main(
            inp_0: R.Tensor((10, 10), dtype="float32"),
            inp_1: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.maximum(inp_0, inp_1)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Max1(), example_args1, {}, expected_max1)

    # Min
    class Min1(Module):
        def forward(self, x, y):
            return torch.min(x, y)

    @I.ir_module
    class expected_min1:
        @R.function
        def main(
            inp_0: R.Tensor((10, 10), dtype="float32"),
            inp_1: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.minimum(inp_0, inp_1)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Min1(), example_args1, {}, expected_min1)

    # RSub: rsub(x, y) computes y - x, so the operand order is reversed below.
    class RSub1(Module):
        def forward(self, x, y):
            return torch.rsub(x, y)

    class RSub2(Module):
        def forward(self, x):
            return torch.rsub(x, 5.0)

    @tvm.script.ir_module
    class expected_rsub1:
        @R.function
        def main(
            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_rsub2:
        @R.function
        def main(
            x: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(RSub1(), example_args1, {}, expected_rsub1)
    verify_model(RSub2(), example_args2, {}, expected_rsub2)


# torch.isin lowering tests


def test_isin():
    """torch.isin lowers to reshape + broadcast R.equal + R.max reduction
    over the trailing (test-elements) axis."""

    class IsInModel(torch.nn.Module):
        def forward(self, x, test_elements):
            return torch.isin(x, test_elements)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
            with R.dataflow():
                # x reshaped to (10, 10, 1) so equality broadcasts against (8,).
                lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, R.shape([10, 10, 1]))
                lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, test_elements)
                # max over the last axis acts as "any match".
                lv2: R.Tensor((10, 10), dtype="bool") = R.max(lv1, axis=[-1], keepdims=False)
                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv2,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(10, 10, dtype=torch.float32),
        torch.randn(8, dtype=torch.float32),
    )
    verify_model(IsInModel(), example_args, {}, expected)


def test_div_mode():
    """torch.div rounding modes: None -> R.divide, "trunc" -> R.divide then
    R.trunc, "floor" -> R.floor_divide. Inputs broadcast (64,64) with (64,)."""
    # Case 1: Basic division (no rounding mode)
    class DivModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.div(a, b)

    @tvm.script.ir_module
    class expected_div:
        @R.function
        def main(
            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b)
                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(64, 64, dtype=torch.float32),
        torch.randn(64, dtype=torch.float32),
    )
    verify_model(DivModel(), example_args, {}, expected_div)

    # Case 2: Division with trunc rounding
    class DivTruncModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.div(a, b, rounding_mode="trunc")

    @tvm.script.ir_module
    class expected_div_trunc:
        @R.function
        def main(
            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b)
                lv1: R.Tensor((64, 64), dtype="float32") = R.trunc(lv)
                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)

    # Case 3: Division with floor rounding
    class DivFloorModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.div(a, b, rounding_mode="floor")

    @tvm.script.ir_module
    class expected_div_floor:
        @R.function
        def main(
            a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32")
        ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64, 64), dtype="float32") = R.floor_divide(a, b)
                gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(DivFloorModel(), example_args, {}, expected_div_floor)


def test_batchnorm2d():
    """BatchNorm2d (default and custom eps/momentum) lowers to
    R.nn.batch_norm in inference mode (training=False); the weight, bias,
    running_mean, and running_var tensors are bound as parameters w1..w4."""

    class BatchNorm2d1(Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, input):
            return self.bn(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3,), dtype="float32"),
            w2: R.Tensor((3,), dtype="float32"),
            w3: R.Tensor((3,), dtype="float32"),
            w4: R.Tensor((3,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                # batch_norm returns (output, running_mean, running_var); only
                # the output (index 0) is used.
                lv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((3,), dtype="float32"),
                    R.Tensor((3,), dtype="float32"),
                ) = R.nn.batch_norm(
                    input_1,
                    w1,
                    w2,
                    w3,
                    w4,
                    axis=1,
                    epsilon=1e-05,
                    center=True,
                    scale=True,
                    momentum=0.1,
                    training=False,
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    # Non-default eps/momentum must be carried through to the Relax op attrs.
    class BatchNorm2dCustom(Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(3, eps=0.001, momentum=0.01)

        def forward(self, input):
            return self.bn(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3,), dtype="float32"),
            w2: R.Tensor((3,), dtype="float32"),
            w3: R.Tensor((3,), dtype="float32"),
            w4: R.Tensor((3,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((3,), dtype="float32"),
                    R.Tensor((3,), dtype="float32"),
                ) = R.nn.batch_norm(
                    input_1,
                    w1,
                    w2,
                    w3,
                    w4,
                    axis=1,
                    epsilon=0.001,
                    center=True,
                    scale=True,
                    momentum=0.01,
                    training=False,
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    # .eval() so export traces inference-mode batch norm (training=False).
    model_1 = BatchNorm2d1().eval()
    binding_1 = {
        "w1": model_1.bn.weight.detach().numpy(),
        "w2": model_1.bn.bias.detach().numpy(),
        "w3": model_1.bn.running_mean.detach().numpy(),
        "w4": model_1.bn.running_var.detach().numpy(),
    }
    verify_model(model_1, example_args, binding_1, expected1)

    model_2 = BatchNorm2dCustom().eval()
    binding_2 = {
        "w1": model_2.bn.weight.detach().numpy(),
        "w2": model_2.bn.bias.detach().numpy(),
        "w3": model_2.bn.running_mean.detach().numpy(),
        "w4": model_2.bn.running_var.detach().numpy(),
    }
    verify_model(model_2, example_args, binding_2, expected2)


def test_adaptive_avgpool1d():
    """AdaptiveAvgPool1d lowers to a 2D pool sandwiched between expand_dims
    and squeeze (the 1D case reuses R.nn.adaptive_avg_pool2d)."""

    class AdaptiveAvgPool1d0(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool1d(output_size=5)

        def forward(self, input):
            return self.pool(input)

    class AdaptiveAvgPool1d1(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool1d(input, output_size=5)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
            with R.dataflow():
                # NCW -> NCHW with a dummy H=1 axis, pool, then squeeze it back.
                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2])
                lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.adaptive_avg_pool2d(
                    lv, output_size=[1, 5], layout="NCHW"
                )
                lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2])
                gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
    verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1)
    verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1)


def test_adaptive_avgpool2d():
    """AdaptiveAvgPool2d (module and functional forms) converts directly to
    R.nn.adaptive_avg_pool2d with NCHW layout."""

    # nn.Module form.
    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, input):
            return self.pool(input)

    # Functional form; must lower to the same IR.
    class AdaptiveAvgPool2d1(Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d(
                    input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW"
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)


def test_adaptive_avgpool3d():
    """AdaptiveAvgPool3d (module and functional forms) converts directly to
    R.nn.adaptive_avg_pool3d with NCDHW layout."""

    # nn.Module form.
    class AdaptiveAvgPool3d0(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool3d([4, 4, 4])

        def forward(self, input):
            return self.pool(input)

    # Functional form; must lower to the same IR.
    class AdaptiveAvgPool3d1(torch.nn.Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool3d(input, [4, 4, 4])

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.adaptive_avg_pool3d(
                    input_1, output_size=[4, 4, 4], layout="NCDHW", out_layout="NCDHW"
                )
                gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
    verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1)
    verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1)


def test_addmm():
    """torch.addmm(x1, x2, x3) = beta*x1 + alpha*(x2 @ x3).

    With the default beta=alpha=1 (expected1) the converter emits matmul + add
    only; with explicit beta/alpha (expected2) each term is scaled via
    R.multiply by the corresponding float constant before the add."""

    # Default scaling: addmm with beta = alpha = 1.
    class Addmm1(Module):
        def __init__(self):
            super().__init__()

        def forward(self, x1, x2, x3):
            return torch.addmm(x1, x2, x3)

    # Explicit scaling factors.
    class Addmm2(Module):
        def __init__(self):
            super().__init__()

        def forward(self, x1, x2, x3):
            return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x1: R.Tensor((10, 10), dtype="float32"),
            x2: R.Tensor((10, 10), dtype="float32"),
            x3: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
                lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            x1: R.Tensor((10, 10), dtype="float32"),
            x2: R.Tensor((10, 10), dtype="float32"),
            x3: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
                # alpha scales the matmul product; beta scales the added input.
                lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32"))
                lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32"))
                lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(10, 10, dtype=torch.float32),
        torch.randn(10, 10, dtype=torch.float32),
        torch.randn(10, 10, dtype=torch.float32),
    )

    verify_model(Addmm1(), example_args, {}, expected1)
    verify_model(Addmm2(), example_args, {}, expected2)


def test_avg_pool1d():
    """AvgPool1d converts to Relax as expand_dims -> avg_pool2d -> squeeze:
    like adaptive pooling, 1-D average pooling is emulated with 2-D pooling
    over an inserted singleton height axis. Covers default, padded/ceil_mode,
    and strided configurations, in both module and functional forms."""

    # Identity-like pooling: kernel 1, stride 1.
    class AvgPool1d1(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool1d(kernel_size=1)

        def forward(self, input):
            return self.pool(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10), dtype="float32")):
            with R.dataflow():
                # Insert a singleton H axis so the 2-D pooling op can be used.
                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
                lv1: R.Tensor((1, 3, 1, 10), dtype="float32") = R.nn.avg_pool2d(
                    lv,
                    pool_size=[1, 1],
                    strides=[1, 1],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    ceil_mode=False,
                    count_include_pad=True,
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv2: R.Tensor((1, 3, 10), dtype="float32") = R.squeeze(lv1, axis=[-2])
                gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv2,)
                R.output(gv)
            return gv

    # Padded + ceil_mode pooling, nn.Module form.
    class AvgPool1d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1, ceil_mode=True)

        def forward(self, input):
            return self.pool(input)

    # Same configuration via the functional API; must produce identical IR.
    class AvgPool1d3(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool1d(
                input, kernel_size=3, stride=2, padding=1, ceil_mode=True
            )

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 6), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
                # 1-D kernel/stride/padding map to the W axis of the 2-D op.
                lv1: R.Tensor((1, 3, 1, 6), dtype="float32") = R.nn.avg_pool2d(
                    lv,
                    pool_size=[1, 3],
                    strides=[1, 2],
                    dilation=[1, 1],
                    padding=[0, 1, 0, 1],
                    ceil_mode=True,
                    count_include_pad=True,
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv2: R.Tensor((1, 3, 6), dtype="float32") = R.squeeze(lv1, axis=[-2])
                gv: R.Tuple(R.Tensor((1, 3, 6), dtype="float32")) = (lv2,)
                R.output(gv)
            return gv

    # Strided, unpadded pooling (downsample by 2).
    class AvgPool1d4(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool1d(input, kernel_size=2, stride=2, padding=0)

    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input, axis=[-2])
                lv1: R.Tensor((1, 3, 1, 5), dtype="float32") = R.nn.avg_pool2d(
                    lv,
                    pool_size=[1, 2],
                    strides=[1, 2],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    ceil_mode=False,
                    count_include_pad=True,
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv2: R.Tensor((1, 3, 5), dtype="float32") = R.squeeze(lv1, axis=[-2])
                gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv2,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
    verify_model(AvgPool1d1(), example_args, {}, expected1)
    verify_model(AvgPool1d2(), example_args, {}, expected2)
    verify_model(AvgPool1d3(), example_args, {}, expected2)
    verify_model(AvgPool1d4(), example_args, {}, expected3)


def test_avg_pool2d():
    """AvgPool2d converts directly to R.nn.avg_pool2d (NCHW). Covers default,
    padded/ceil_mode, and divisor_override configurations, in both module and
    functional forms. (There is intentionally no `expected3`; the class names
    jump from expected2 to expected4.)"""

    # Identity-like pooling: kernel [1, 1].
    class AvgPool2d1(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, input):
            return self.pool(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                # ceil_mode is omitted here and takes the op's default (False).
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.avg_pool2d(
                    input_1,
                    pool_size=[1, 1],
                    strides=[1, 1],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    count_include_pad=True,
                    layout="NCHW",
                    out_layout="NCHW",
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Padded + ceil_mode pooling, nn.Module form.
    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)

        def forward(self, input):
            return self.pool(input)

    # Same configuration via the functional API; must produce identical IR.
    class AvgPool2d3(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool2d(
                input, kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True
            )

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv = R.nn.avg_pool2d(
                    input_1,
                    pool_size=[4, 4],
                    strides=[2, 2],
                    dilation=[1, 1],
                    padding=[2, 2, 2, 2],
                    ceil_mode=True,
                    count_include_pad=True,
                    layout="NCHW",
                    out_layout="NCHW",
                )
                gv = (lv,)
                R.output(gv)
            return gv

    # divisor_override with default count_include_pad.
    class AvgPool2d4(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2)

    @tvm.script.ir_module
    class expected4:
        @R.function
        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                # NOTE(review): divisor_override=2 does not appear in the
                # expected IR — presumably the converter drops it when it does
                # not change the result; confirm against the frontend code.
                lv = R.nn.avg_pool2d(
                    input_1,
                    pool_size=[2, 1],
                    strides=[2, 1],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    ceil_mode=False,
                    count_include_pad=True,
                    layout="NCHW",
                    out_layout="NCHW",
                )
                gv = (lv,)
                R.output(gv)
            return gv

    # divisor_override combined with count_include_pad=False.
    class AvgPool2d5(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool2d(
                input, kernel_size=[2, 1], divisor_override=2, count_include_pad=False
            )

    @tvm.script.ir_module
    class expected5:
        @R.function
        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv = R.nn.avg_pool2d(
                    input_1,
                    pool_size=[2, 1],
                    strides=[2, 1],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    ceil_mode=False,
                    count_include_pad=False,
                    layout="NCHW",
                    out_layout="NCHW",
                )
                gv = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(AvgPool2d1(), example_args, {}, expected1)
    verify_model(AvgPool2d2(), example_args, {}, expected2)
    verify_model(AvgPool2d3(), example_args, {}, expected2)
    verify_model(AvgPool2d4(), example_args, {}, expected4)
    verify_model(AvgPool2d5(), example_args, {}, expected5)


def test_avg_pool3d():
    """AvgPool3d converts directly to R.nn.avg_pool3d (NCDHW). Covers default,
    padded/ceil_mode, and asymmetric-kernel configurations, in both module and
    functional forms."""

    # Identity-like pooling: kernel 1.
    class AvgPool3d1(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool3d(kernel_size=1)

        def forward(self, input):
            return self.pool(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.avg_pool3d(
                    input_1,
                    pool_size=[1, 1, 1],
                    strides=[1, 1, 1],
                    dilation=[1, 1, 1],
                    padding=[0, 0, 0, 0, 0, 0],
                    ceil_mode=False,
                    count_include_pad=True,
                    layout="NCDHW",
                    out_layout="NCDHW",
                )
                gv: R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Padded + ceil_mode pooling, nn.Module form.
    class AvgPool3d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True)

        def forward(self, input):
            return self.pool(input)

    # Same configuration via the functional API; must produce identical IR.
    class AvgPool3d3(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool3d(
                input, kernel_size=3, stride=2, padding=1, ceil_mode=True
            )

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
            with R.dataflow():
                # Scalar torch kernel/stride/padding expand to all 3 axes.
                lv = R.nn.avg_pool3d(
                    input_1,
                    pool_size=[3, 3, 3],
                    strides=[2, 2, 2],
                    dilation=[1, 1, 1],
                    padding=[1, 1, 1, 1, 1, 1],
                    ceil_mode=True,
                    count_include_pad=True,
                    layout="NCDHW",
                    out_layout="NCDHW",
                )
                gv = (lv,)
                R.output(gv)
            return gv

    # Asymmetric kernel and stride.
    class AvgPool3d4(Module):
        def forward(self, input):
            return torch.nn.functional.avg_pool3d(input, kernel_size=[2, 1, 2], stride=[2, 1, 2])

    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")):
            with R.dataflow():
                lv = R.nn.avg_pool3d(
                    input_1,
                    pool_size=[2, 1, 2],
                    strides=[2, 1, 2],
                    dilation=[1, 1, 1],
                    padding=[0, 0, 0, 0, 0, 0],
                    ceil_mode=False,
                    count_include_pad=True,
                    layout="NCDHW",
                    out_layout="NCDHW",
                )
                gv = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
    verify_model(AvgPool3d1(), example_args, {}, expected1)
    verify_model(AvgPool3d2(), example_args, {}, expected2)
    verify_model(AvgPool3d3(), example_args, {}, expected2)
    verify_model(AvgPool3d4(), example_args, {}, expected3)


def test_baddbmm():
    """torch.baddbmm(c, x, y) = beta*c + alpha*(x @ y) on batched 3-D tensors.

    Three cases: defaults (matmul + add), beta=0 (the input term is dropped
    entirely, leaving only the scaled matmul), and nonzero alpha/beta (both
    terms scaled via R.multiply before the add)."""

    # Default scaling: beta = alpha = 1.
    class BAddBMM1(Module):
        def __init__(self):
            super().__init__()

        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
                    inp_1, inp_2, out_dtype="float32"
                )
                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(inp_0, lv)
                gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    # beta=0: the additive input c must not appear in the result.
    class BAddBMM2(Module):
        def __init__(self):
            super().__init__()

        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
                    inp_1, inp_2, out_dtype="float32"
                )
                # No add of inp_0: beta=0 eliminates the input term.
                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                    lv, R.const(2, "float32")
                )
                gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    # Both scaling factors nonzero and != 1.
    class BAddBMM3(Module):
        def __init__(self):
            super().__init__()

        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=3)

    @tvm.script.ir_module
    class Expected3:
        @R.function
        def main(
            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
                    inp_1, inp_2, out_dtype="float32"
                )
                # alpha scales the matmul product; beta scales the added input.
                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                    lv, R.const(2, "float32")
                )
                lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                    inp_0, R.const(3, "float32")
                )
                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv2, lv1)
                gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(4, 128, 512, dtype=torch.float32),
        torch.randn(4, 128, 256, dtype=torch.float32),
        torch.randn(4, 256, 512, dtype=torch.float32),
    )
    verify_model(
        BAddBMM1(),
        example_args,
        {},
        Expected1,
        run_ep_decomposition=True,
    )

    verify_model(
        BAddBMM2(),
        example_args,
        {},
        Expected2,
        run_ep_decomposition=True,
    )

    verify_model(
        BAddBMM3(),
        example_args,
        {},
        Expected3,
        run_ep_decomposition=True,
    )


def test_bmm():
    """torch.bmm on batched 3-D inputs converts to a single R.matmul call."""

    class BatchMatMul(Module):
        def forward(self, lhs, rhs):
            return torch.bmm(lhs, rhs)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input_1: R.Tensor((4, 128, 256), dtype="float32"),
            input_2: R.Tensor((4, 256, 512), dtype="float32"),
        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
            with R.dataflow():
                prod: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
                    input_1, input_2, out_dtype="float32"
                )
                result: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (prod,)
                R.output(result)
            return result

    sample_inputs = (
        torch.randn(4, 128, 256, dtype=torch.float32),
        torch.randn(4, 256, 512, dtype=torch.float32),
    )
    verify_model(BatchMatMul(), sample_inputs, {}, Expected, run_ep_decomposition=True)


def test_conv_transpose1d():
    """ConvTranspose1d converts to R.nn.conv1d_transpose with kernel layout
    IOW (input channels leading, unlike the OIW layout of regular conv1d).
    A bias, when present, is reshaped to (1, C, 1) and added separately."""

    # nn.Module form with bias.
    class ConvTranspose1d1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True)

        def forward(self, input):
            return self.conv(input)

    # Functional form with explicit weight/bias tensors; same expected IR.
    class ConvTranspose1d1Func(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(size=[6, 6, 3])
            self.bias = torch.randn(size=[6])

        def forward(self, input):
            return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 6, 4), dtype="float32"),
            w1: R.Tensor((6, 6, 3), dtype="float32"),
            w2: R.Tensor((6,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose(
                    input_1,
                    w1,
                    strides=[1],
                    padding=[0, 0],
                    output_padding=[0],
                    dilation=[1],
                    data_layout="NCW",
                    kernel_layout="IOW",
                    out_layout="NCW",
                    out_dtype="float32",
                )
                # Bias broadcast over batch and width via reshape to (1, C, 1).
                lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
                lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # nn.Module form without bias: no reshape/add in the expected IR.
    class ConvTranspose1d2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 6, 4), dtype="float32"),
            w1: R.Tensor((6, 6, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose(
                    input_1,
                    w1,
                    strides=[1],
                    padding=[0, 0],
                    output_padding=[0],
                    dilation=[1],
                    data_layout="NCW",
                    kernel_layout="IOW",
                    out_layout="NCW",
                    out_dtype="float32",
                )
                gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 6, 4, dtype=torch.float32),)

    # Bind the real parameter values into the expected module before comparing.
    model = ConvTranspose1d1()
    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = ConvTranspose1d1Func()
    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = ConvTranspose1d2()
    binding = {"w1": model.conv.weight.detach().numpy()}
    verify_model(model, example_args, binding, expected2)


def test_conv_transpose2d():
    """ConvTranspose2d converts to R.nn.conv2d_transpose with kernel layout
    IOHW (input channels leading, unlike the OIHW layout of regular conv2d).
    A bias, when present, is reshaped to (1, C, 1, 1) and added separately."""

    # nn.Module form with bias.
    class ConvTranspose2d1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True)

        def forward(self, input):
            return self.conv(input)

    # Functional form with explicit weight/bias tensors; same expected IR.
    class ConvTranspose2d1Func(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(size=[3, 3, 7, 7])
            self.bias = torch.randn(size=[3])

        def forward(self, input):
            return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3, 3, 7, 7), dtype="float32"),
            w2: R.Tensor((3,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose(
                    input_1,
                    w1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    output_padding=[0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="IOHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                # Bias broadcast over batch and spatial dims via (1, C, 1, 1).
                lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
                lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # nn.Module form without bias: no reshape/add in the expected IR.
    class ConvTranspose2d2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3, 3, 7, 7), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose(
                    input_1,
                    w1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    output_padding=[0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="IOHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    # Bind the real parameter values into the expected module before comparing.
    model = ConvTranspose2d1()
    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = ConvTranspose2d1Func()
    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = ConvTranspose2d2()
    binding = {"w1": model.conv.weight.detach().numpy()}
    verify_model(model, example_args, binding, expected2)


def test_conv1d():
    """Conv1d converts to R.nn.conv1d with kernel layout OIW. A bias, when
    present, is reshaped to (1, C, 1) and added separately. Note the expected
    signatures list the weight parameters before the data input here, unlike
    the conv_transpose tests above."""

    # nn.Module form with bias.
    class Conv1D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)

        def forward(self, input):
            return self.conv(input)

    # Functional form with explicit weight/bias tensors; same expected IR.
    class Conv1D1Func(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(size=[6, 3, 7])
            self.bias = torch.randn(size=[6])

        def forward(self, input):
            return torch.nn.functional.conv1d(input, self.weight, self.bias)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            w1: R.Tensor((6, 3, 7), dtype="float32"),
            w2: R.Tensor((6,), dtype="float32"),
            input_1: R.Tensor((1, 3, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
                    input_1,
                    w1,
                    strides=[1],
                    padding=[0, 0],
                    dilation=[1],
                    data_layout="NCW",
                    kernel_layout="OIW",
                    out_layout="NCW",
                    out_dtype="float32",
                )
                # Bias broadcast over batch and width via reshape to (1, C, 1).
                lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
                lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # nn.Module form without bias: no reshape/add in the expected IR.
    class Conv1D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            w1: R.Tensor((6, 3, 7), dtype="float32"),
            input_1: R.Tensor((1, 3, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
                    input_1,
                    w1,
                    strides=[1],
                    padding=[0, 0],
                    dilation=[1],
                    data_layout="NCW",
                    kernel_layout="OIW",
                    out_layout="NCW",
                    out_dtype="float32",
                )
                gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)

    # Bind the real parameter values into the expected module before comparing.
    model = Conv1D1()
    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Conv1D1Func()
    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Conv1D2()
    binding = {"w1": model.conv.weight.detach().numpy()}
    verify_model(model, example_args, binding, expected2)


def test_conv2d():
    """Test conversion of 2D convolution to Relax ``R.nn.conv2d``.

    Covers three variants: ``nn.Conv2d`` with bias, functional
    ``F.conv2d`` with explicit weight/bias tensors (both share
    ``expected1``), and ``nn.Conv2d`` without bias (``expected2``).
    """

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, input):
            return self.conv(input)

    class Conv2D1Func(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(size=[6, 3, 7, 7])
            self.bias = torch.randn(size=[6])

        def forward(self, input):
            return torch.nn.functional.conv2d(input, self.weight, self.bias)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
            w2: R.Tensor((6,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
                    input_1,
                    w1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                # The bias is reshaped to (1, C, 1, 1) and broadcast-added.
                lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
                    input_1,
                    w1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    model = Conv2D1()
    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Conv2D1Func()
    # Use .detach() before .numpy() for consistency with the other conv
    # tests (and so this keeps working should the tensors require grad).
    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Conv2D2()
    binding = {"w1": model.conv.weight.detach().numpy()}
    verify_model(model, example_args, binding, expected2)


def test_conv3d():
    """Test conversion of 3D convolution to Relax ``R.nn.conv3d``.

    Covers ``nn.Conv3d`` with bias, functional ``F.conv3d`` with explicit
    weight/bias tensors (both share ``expected1``), and ``nn.Conv3d``
    without bias (``expected2``).
    """

    class Conv3D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv3d(3, 6, 7, bias=True)

        def forward(self, input):
            return self.conv(input)

    class Conv3D1Func(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(size=[6, 3, 7, 7, 7])
            self.bias = torch.randn(size=[6])

        def forward(self, input):
            return torch.nn.functional.conv3d(input, self.weight, self.bias)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
            w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
            w2: R.Tensor((6,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
                    input_1,
                    w1,
                    strides=[1],
                    padding=[0, 0, 0],
                    dilation=[1],
                    data_layout="NCDHW",
                    kernel_layout="OIDHW",
                    out_layout="NCDHW",
                    out_dtype="float32",
                )
                # The bias is reshaped to (1, C, 1, 1, 1) and broadcast-added.
                lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
                lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    class Conv3D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv3d(3, 6, 7, bias=False)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
            w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
                    input_1,
                    w1,
                    strides=[1],
                    padding=[0, 0, 0],
                    dilation=[1],
                    data_layout="NCDHW",
                    kernel_layout="OIDHW",
                    out_layout="NCDHW",
                    out_dtype="float32",
                )
                gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)

    model = Conv3D1()
    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Conv3D1Func()
    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Conv3D2()
    binding = {"w1": model.conv.weight.detach().numpy()}
    verify_model(model, example_args, binding, expected2)


def test_pad():
    """Test conversion of ``torch.nn.functional.pad`` for all four modes.

    The "constant" mode maps directly to a single ``R.nn.pad`` call. The
    "reflect", "replicate" and "circular" modes go through the exported
    program's decomposition, which lowers them to index arithmetic
    (``arange``/``abs``/``take``), index clipping (``clip``/``take``), and
    ``strided_slice``/``slice_scatter`` copy sequences respectively.
    """

    class PadModel(torch.nn.Module):
        def __init__(self, pad, mode="constant", value=0.0):
            super().__init__()
            self.pad = pad
            self.mode = mode
            self.value = value

        def forward(self, x):
            # F.pad rejects a `value` argument for non-constant modes, so
            # only pass it for "constant".
            if self.mode == "constant":
                return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value)
            else:
                return torch.nn.functional.pad(x, self.pad, mode=self.mode)

    # Constant padding lowers to a single R.nn.pad call.
    @tvm.script.ir_module
    class expected_constant:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad(
                    x,
                    pad_width=[0, 0, 0, 0, 2, 2, 1, 1],
                    pad_mode="constant",
                    pad_value=0.0,
                )
                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Reflect padding decomposes into mirrored index computation plus take,
    # applied once per padded axis.
    @tvm.script.ir_module
    class expected_reflect:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((14,), dtype="int64") = R.arange(
                    R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((14,), dtype="int64") = R.abs(lv)
                lv2: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv1)
                lv3: R.Tensor((14,), dtype="int64") = R.abs(lv2)
                lv4: R.Tensor((14,), dtype="int64") = R.subtract(R.const(9, "int64"), lv3)
                lv5: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv4, axis=2, mode="fast")
                lv6: R.Tensor((12,), dtype="int64") = R.arange(
                    R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
                )
                lv7: R.Tensor((12,), dtype="int64") = R.abs(lv6)
                lv8: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv7)
                lv9: R.Tensor((12,), dtype="int64") = R.abs(lv8)
                lv10: R.Tensor((12,), dtype="int64") = R.subtract(R.const(9, "int64"), lv9)
                lv11: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
                    lv5, lv10, axis=3, mode="fast"
                )
                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv11,)
                R.output(gv)
            return gv

    # Replicate padding decomposes into clipped indices plus take,
    # applied once per padded axis.
    @tvm.script.ir_module
    class expected_replicate:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((14,), dtype="int64") = R.arange(
                    R.prim_value(-2), R.prim_value(12), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((14,), dtype="int64") = R.clip(lv, R.prim_value(0), R.prim_value(9))
                lv2: R.Tensor((1, 3, 14, 10), dtype="float32") = R.take(x, lv1, axis=2, mode="fast")
                lv3: R.Tensor((12,), dtype="int64") = R.arange(
                    R.prim_value(-1), R.prim_value(11), R.prim_value(1), dtype="int64"
                )
                lv4: R.Tensor((12,), dtype="int64") = R.clip(lv3, R.prim_value(0), R.prim_value(9))
                lv5: R.Tensor((1, 3, 14, 12), dtype="float32") = R.take(
                    lv2, lv4, axis=3, mode="fast"
                )
                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv5,)
                R.output(gv)
            return gv

    # Circular padding decomposes into a zero-filled output buffer that is
    # progressively filled with strided_slice / slice_scatter copies of the
    # input and of the wrapped-around border regions.
    @tvm.script.ir_module
    class expected_circular:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.zeros(
                    R.shape([1, 3, 14, 12]), dtype="float32"
                )

                lv1: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
                    lv,
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    (R.prim_value(11),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    x,
                    (R.prim_value(3),),
                    (R.prim_value(0),),
                    (R.prim_value(10),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    lv1,
                    (R.prim_value(2),),
                    (R.prim_value(2),),
                    (R.prim_value(12),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    lv2,
                    (R.prim_value(2),),
                    (R.prim_value(0),),
                    (R.prim_value(10),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
                    lv4, R.shape([1, 3, 10, 10])
                )

                lv6: R.Tensor((1, 3, 14, 10), dtype="float32") = R.strided_slice(
                    lv,
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    (R.prim_value(11),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv7: R.Tensor((1, 3, 14, 10), dtype="float32") = R.slice_scatter(
                    lv6, lv5, R.prim_value(2), R.prim_value(12), R.prim_value(1), axis=2
                )

                lv8: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
                    lv, lv7, R.prim_value(1), R.prim_value(11), R.prim_value(1), axis=3
                )

                lv9: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
                    lv8,
                    (R.prim_value(3),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv10: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
                    lv8,
                    (R.prim_value(3),),
                    (R.prim_value(10),),
                    (R.prim_value(11),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv11: R.Tensor((1, 3, 14, 1), dtype="float32") = R.broadcast_to(
                    lv10, R.shape([1, 3, 14, 1])
                )

                lv12: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
                    lv8, lv11, R.prim_value(0), R.prim_value(1), R.prim_value(1), axis=3
                )

                lv13: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
                    lv12,
                    (R.prim_value(3),),
                    (R.prim_value(11),),
                    (R.prim_value(12),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv14: R.Tensor((1, 3, 14, 1), dtype="float32") = R.strided_slice(
                    lv12,
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv15: R.Tensor((1, 3, 14, 1), dtype="float32") = R.broadcast_to(
                    lv14, R.shape([1, 3, 14, 1])
                )
                lv16: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
                    lv12, lv15, R.prim_value(11), R.prim_value(12), R.prim_value(1), axis=3
                )

                lv17: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
                    lv16,
                    (R.prim_value(2),),
                    (R.prim_value(0),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv18: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
                    lv16,
                    (R.prim_value(2),),
                    (R.prim_value(10),),
                    (R.prim_value(12),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv19: R.Tensor((1, 3, 2, 12), dtype="float32") = R.broadcast_to(
                    lv18, R.shape([1, 3, 2, 12])
                )

                lv20: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
                    lv16, lv19, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=2
                )
                lv21: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
                    lv20,
                    (R.prim_value(2),),
                    (R.prim_value(12),),
                    (R.prim_value(14),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv22: R.Tensor((1, 3, 2, 12), dtype="float32") = R.strided_slice(
                    lv20,
                    (R.prim_value(2),),
                    (R.prim_value(2),),
                    (R.prim_value(4),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )

                lv23: R.Tensor((1, 3, 2, 12), dtype="float32") = R.broadcast_to(
                    lv22, R.shape([1, 3, 2, 12])
                )

                lv24: R.Tensor((1, 3, 14, 12), dtype="float32") = R.slice_scatter(
                    lv20, lv23, R.prim_value(12), R.prim_value(14), R.prim_value(1), axis=2
                )
                gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv24,)
                R.output(gv)
            return gv

    # pad=[1, 1, 2, 2] pads the last dim by 1 on each side and the
    # second-to-last dim by 2 on each side: (1, 3, 10, 10) -> (1, 3, 14, 12).
    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant)
    verify_model(
        PadModel(pad=[1, 1, 2, 2], mode="reflect"),
        example_args,
        {},
        expected_reflect,
        run_ep_decomposition=True,
    )
    verify_model(
        PadModel(pad=[1, 1, 2, 2], mode="replicate"),
        example_args,
        {},
        expected_replicate,
        run_ep_decomposition=True,
    )
    verify_model(
        PadModel(pad=[1, 1, 2, 2], mode="circular"),
        example_args,
        {},
        expected_circular,
        run_ep_decomposition=True,
    )


def test_pixel_shuffle():
    """Test conversion of pixel shuffle (module and functional forms).

    PixelShuffle with upscale factor r lowers to a
    reshape -> permute_dims -> reshape sequence that rearranges
    (N, C*r*r, H, W) into (N, C, H*r, W*r).
    """

    class PixelShuffle1(torch.nn.Module):
        def __init__(self, upscale_factor=2):
            super().__init__()
            self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor)

        def forward(self, x):
            return self.pixel_shuffle(x)

    class PixelShuffle2(torch.nn.Module):
        def __init__(self, upscale_factor=2):
            super().__init__()
            self.upscale_factor = upscale_factor

        def forward(self, x):
            return torch.nn.functional.pixel_shuffle(x, self.upscale_factor)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            x: R.Tensor((1, 8, 10, 15), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")):
            with R.dataflow():
                # Split the channel dim into (C, r, r) ...
                lv: R.Tensor((1, 2, 2, 2, 10, 15), dtype="float32") = R.reshape(
                    x, R.shape([1, 2, 2, 2, 10, 15])
                )
                # ... interleave the r factors with the spatial dims ...
                lv1: R.Tensor((1, 2, 10, 2, 15, 2), dtype="float32") = R.permute_dims(
                    lv, axes=[0, 1, 4, 2, 5, 3]
                )
                # ... and merge them into the upscaled spatial dims.
                lv2: R.Tensor((1, 2, 20, 30), dtype="float32") = R.reshape(
                    lv1, R.shape([1, 2, 20, 30])
                )
                gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv2,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),)
    verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected)
    verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected)


def test_einsum():
    """Test conversion of ``torch.einsum`` to ``R.einsum``.

    Covers a single-operand trace-style subscript ("ii") and a
    two-operand outer-product subscript ("i,j->ij"). Both run without
    decomposition so einsum is preserved as a single op.
    """

    # NOTE: the original classes carried no-op __init__ methods that only
    # called super().__init__(); nn.Module provides that by default.
    class Einsum1(Module):
        def forward(self, x):
            return torch.einsum("ii", x)

    class Einsum2(Module):
        def forward(self, x, y):
            return torch.einsum("i,j->ij", x, y)

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((4, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                # "ii" contracts the diagonal to a scalar.
                lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii")
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32")
        ) -> R.Tuple(R.Tensor((5, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
                    (inp_0, inp_1), subscripts="i,j->ij"
                )
                gv: R.Tuple(R.Tensor((5, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(4, 4, dtype=torch.float32),)
    verify_model(Einsum1(), example_args, {}, Expected1, run_ep_decomposition=False)

    example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32))
    verify_model(Einsum2(), example_args, {}, Expected2, run_ep_decomposition=False)


def test_outer():
    """Test conversion of ``torch.outer``.

    The outer product lowers to reshaping the first vector into a column
    (n, 1) followed by a broadcast multiply with the second vector.
    """

    class Outer(torch.nn.Module):
        def forward(self, x, y):
            return torch.outer(x, y)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            x: R.Tensor((3,), dtype="float32"), y: R.Tensor((4,), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3, 1), dtype="float32") = R.reshape(x, R.shape([3, 1]))
                lv1: R.Tensor((3, 4), dtype="float32") = R.multiply(lv, y)
                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(3, dtype=torch.float32),
        torch.randn(4, dtype=torch.float32),
    )
    verify_model(Outer(), example_args, {}, expected)


def test_embedding():
    """Test conversion of ``torch.nn.Embedding`` to cast + ``R.take``.

    The int64 indices are cast to int32 and used to gather rows from the
    embedding weight along axis 0.
    """

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, input):
            return self.embedding(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32")
                lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0)
                gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    # NOTE(review): indices are sampled far outside [0, 10) here — this test
    # only checks structural equality of the converted IR, not execution.
    example_args = (torch.randint(low=-int(1e5), high=int(1e5), size=(4,), dtype=torch.int64),)

    model = Embedding()
    binding = {"w1": model.embedding.weight.detach().numpy()}
    verify_model(model, example_args, binding, expected1)


def test_groupnorm():
    """Test conversion of ``torch.nn.GroupNorm`` to ``R.nn.group_norm``.

    GroupNorm(3, 3) normalizes 3 channels in 3 groups (one channel per
    group) over the spatial axes [2, 3].
    """
    # NOTE: the original function redundantly re-imported `torch` and
    # `Module` here; both are already imported at module level.
    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.gn = torch.nn.GroupNorm(3, 3)

        def forward(self, input):
            return self.gn(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3,), dtype="float32"),
            w2: R.Tensor((3,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm(
                    input_1,
                    w1,
                    w2,
                    num_groups=3,
                    channel_axis=1,
                    axes=[2, 3],
                    epsilon=1.0000000000000001e-05,
                    center=True,
                    scale=True,
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    model = GroupNorm()
    binding = {
        "w1": model.gn.weight.detach().numpy(),
        "w2": model.gn.bias.detach().numpy(),
    }
    verify_model(model, example_args, binding, expected1)


def test_instancenorm2d():
    """Test conversion of ``torch.nn.InstanceNorm2d`` to ``R.nn.instance_norm``."""
    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    class InstanceNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.gn = torch.nn.InstanceNorm2d(3)

        def forward(self, input):
            return self.gn(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3,), dtype="float32"),
            w2: R.Tensor((3,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.instance_norm(
                    input_1,
                    w1,
                    w2,
                    channel_axis=1,
                    axes=[2, 3],
                    epsilon=1e-05,
                    center=True,
                    scale=True,
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    model = InstanceNorm2d()
    # InstanceNorm2d is used without learned affine parameters here, so the
    # bound gamma/beta are constant ones and zeros rather than taken from
    # the model (presumably because affine defaults to False — the module
    # has no weight/bias attributes to extract).
    binding = {
        "w1": torch.ones(3).detach().numpy(),
        "w2": torch.zeros(3).detach().numpy(),
    }
    verify_model(model, example_args, binding, expected1)


def test_layernorm():
    """Test conversion of ``torch.nn.LayerNorm`` to ``R.nn.layer_norm``.

    LayerNorm over a (10, 10) normalized shape maps to layer_norm over
    the last two axes with elementwise affine weight/bias.
    """

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.ln = torch.nn.LayerNorm((10, 10))

        def forward(self, input):
            return self.ln(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((10, 10), dtype="float32"),
            w2: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
                    input_1,
                    w1,
                    w2,
                    axes=[-2, -1],
                    epsilon=1e-05,
                    center=True,
                    scale=True,
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    model = LayerNorm()
    binding = {
        "w1": model.ln.weight.detach().numpy(),
        "w2": model.ln.bias.detach().numpy(),
    }
    # Verify the same instance the binding was extracted from. (Previously a
    # fresh LayerNorm() was passed here, which only matched the binding
    # because LayerNorm initializes its parameters deterministically.)
    verify_model(model, example_args, binding, expected1)


def test_linear():
    """Test conversion of linear layers (module and functional forms).

    A linear layer on a 4-D input lowers to: flatten the leading dims to
    2-D, transpose the weight, matmul, add the bias (when present), and
    reshape back to the original leading dims.
    """

    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, input):
            return self.linear(input)

    class Dense1Func(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.randn(size=[7, 10])
            self.bias = torch.randn(size=[7])

        def forward(self, input):
            return torch.nn.functional.linear(input, self.weight, self.bias)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            w1: R.Tensor((7, 10), dtype="float32"),
            w2: R.Tensor((7,), dtype="float32"),
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
            # block 0
            with R.dataflow():
                # Flatten (1, 3, 10, 10) to (30, 10) for a 2-D matmul.
                lv: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10]))
                lv1: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
                lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
                # Bias (7,) broadcasts against (30, 7).
                lv3: R.Tensor((30, 7), dtype="float32") = R.add(w2, lv2)
                lv4: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(
                    lv3, R.shape([1, 3, 10, 7])
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv4,)
                R.output(gv)
            return gv

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=False)

        def forward(self, input):
            return self.linear(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            w1: R.Tensor((7, 10), dtype="float32"),
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
                lv1: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10]))
                lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv1, lv, out_dtype="float32")
                lv3: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(
                    lv2, R.shape([1, 3, 10, 7])
                )
                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    model = Dense1()
    binding = {"w1": model.linear.weight.detach().numpy(), "w2": model.linear.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Dense1Func()
    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
    verify_model(model, example_args, binding, expected1)

    model = Dense2()
    binding = {"w1": model.linear.weight.detach().numpy()}
    verify_model(model, example_args, binding, expected2)


def test_maxpool1d():
    """Test lowering of torch MaxPool1d (module, functional, and strided forms).

    The frontend implements 1D max-pooling via R.nn.max_pool2d: the input is
    expanded to rank 4 (expand_dims on axis -2), pooled with a [1, k] window,
    then squeezed back to rank 3.  The zeros_like/tuple/getitem sequence in the
    expected IR pairs the pooled values with a placeholder second element
    (presumably standing in for the unused indices output of the aten op --
    confirm against the frontend) and selects element 0.
    """

    class MaxPool1d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool1d(kernel_size=2)

        def forward(self, input):
            return self.pool(input)

    class MaxPool1d_functional(Module):
        def __init__(self):
            super().__init__()

        def forward(self, input):
            return torch.nn.functional.max_pool1d(input, kernel_size=2)

    class MaxPool1d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2)

        def forward(self, input):
            return self.pool(input)

    # Shared expected IR for the kernel_size=2 module and functional variants:
    # they lower identically, so one module serves both (the same pattern
    # test_maxpool2d and test_maxpool3d use for their functional variants).
    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 8), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2])
                lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
                    lv,
                    pool_size=[1, 2],
                    strides=[1, 2],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
                lv3: R.Tuple(
                    R.Tensor((1, 3, 1, 4), dtype="float32"),
                    R.Tensor((1, 3, 1, 4), dtype="float32"),
                ) = (lv1, lv2)
                lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
                lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
                gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
                R.output(gv)
            return gv

    # Expected IR for kernel_size=3, stride=2 on a length-10 input.
    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2])
                lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
                    lv,
                    pool_size=[1, 3],
                    strides=[1, 2],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
                lv3: R.Tuple(
                    R.Tensor((1, 3, 1, 4), dtype="float32"),
                    R.Tensor((1, 3, 1, 4), dtype="float32"),
                ) = (lv1, lv2)
                lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
                lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
                gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
                R.output(gv)
            return gv

    # Example inputs
    example_args1 = (torch.randn(1, 3, 8, dtype=torch.float32),)
    example_args2 = (torch.randn(1, 3, 10, dtype=torch.float32),)

    # Verify the models; the module and functional variants share expected1.
    verify_model(MaxPool1d(), example_args1, {}, expected1)
    verify_model(MaxPool1d_functional(), example_args1, {}, expected1)
    verify_model(MaxPool1d2(), example_args2, {}, expected2)


def test_maxpool2d():
    """Test lowering of torch MaxPool2d to R.nn.max_pool2d.

    Covers: default stride (= kernel size), the functional form, explicit
    dilation, and explicit padding + stride.  The zeros_like/tuple/getitem
    tail in each expected IR pairs the pooled values with a placeholder
    second element (presumably the unused indices output of the exported
    aten op -- confirm against the frontend) and selects element 0.
    """

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, input):
            return self.pool(input)

    class MaxPool2d_functional(Module):
        def __init__(self):
            super().__init__()

        def forward(self, input):
            return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1])

    # Shared by both the module and functional kernel_size=[1, 1] variants.
    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d(
                    input_1,
                    pool_size=[1, 1],
                    strides=[1, 1],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros_like(lv)
                lv2: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                ) = (lv, lv1)
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2[0]
                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, input):
            return self.pool(input)

    # dilation=[2, 3]; stride defaults to the kernel size ([2, 2]).
    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d(
                    input_1,
                    pool_size=[2, 2],
                    strides=[2, 2],
                    dilation=[2, 3],
                    padding=[0, 0, 0, 0],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv1: R.Tensor((1, 3, 4, 4), dtype="float32") = R.zeros_like(lv)
                lv2: R.Tuple(
                    R.Tensor((1, 3, 4, 4), dtype="float32"), R.Tensor((1, 3, 4, 4), dtype="float32")
                ) = (lv, lv1)
                lv3: R.Tensor((1, 3, 4, 4), dtype="float32") = lv2[0]
                gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, input):
            return self.pool(input)

    # padding=2 expands to 4-sided padding [2, 2, 2, 2] in the Relax op.
    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d(
                    input_1,
                    pool_size=[4, 4],
                    strides=[2, 2],
                    dilation=[1, 1],
                    padding=[2, 2, 2, 2],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                lv1: R.Tensor((1, 3, 6, 6), dtype="float32") = R.zeros_like(lv)
                lv2: R.Tuple(
                    R.Tensor((1, 3, 6, 6), dtype="float32"), R.Tensor((1, 3, 6, 6), dtype="float32")
                ) = (lv, lv1)
                lv3: R.Tensor((1, 3, 6, 6), dtype="float32") = lv2[0]
                gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(MaxPool2d(), example_args, {}, expected1)
    verify_model(MaxPool2d_functional(), example_args, {}, expected1)
    verify_model(MaxPool2d2(), example_args, {}, expected2)
    verify_model(MaxPool2d3(), example_args, {}, expected3)


def test_maxpool3d():
    """Test lowering of torch MaxPool3d to R.nn.max_pool3d (NCDHW layout).

    Covers: default stride (= kernel size), the functional form, explicit
    dilation, and explicit padding + stride.  The zeros_like/tuple/getitem
    tail in each expected IR pairs the pooled values with a placeholder
    second element (presumably the unused indices output of the exported
    aten op -- confirm against the frontend) and selects element 0.
    """

    class MaxPool3d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1])

        def forward(self, input):
            return self.pool(input)

    class MaxPool3d_functional(Module):
        def __init__(self):
            super().__init__()

        def forward(self, input):
            return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1])

    # Shared by both the module and functional kernel_size=[1, 1, 1] variants.
    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.max_pool3d(
                    input_1,
                    pool_size=[1, 1, 1],
                    strides=[1, 1, 1],
                    dilation=[1, 1, 1],
                    padding=[0, 0, 0, 0, 0, 0],
                    layout="NCDHW",
                    out_layout="NCDHW",
                )
                lv1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.zeros_like(lv)
                lv2: R.Tuple(
                    R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
                    R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
                ) = (lv, lv1)
                lv3: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = lv2[0]
                gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    class MaxPool3d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[2, 2, 2])

        def forward(self, input):
            return self.pool(input)

    # dilation=[2, 2, 2]; stride defaults to the kernel size ([2, 2, 2]).
    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.nn.max_pool3d(
                    input_1,
                    pool_size=[2, 2, 2],
                    strides=[2, 2, 2],
                    dilation=[2, 2, 2],
                    padding=[0, 0, 0, 0, 0, 0],
                    layout="NCDHW",
                    out_layout="NCDHW",
                )
                lv1: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.zeros_like(lv)
                lv2: R.Tuple(
                    R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
                    R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
                ) = (lv, lv1)
                lv3: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = lv2[0]
                gv: R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    class MaxPool3d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2)

        def forward(self, input):
            return self.pool(input)

    # padding=1 expands to 6-sided padding [1, 1, 1, 1, 1, 1] in the Relax op.
    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d(
                    input_1,
                    pool_size=[3, 3, 3],
                    strides=[2, 2, 2],
                    dilation=[1, 1, 1],
                    padding=[1, 1, 1, 1, 1, 1],
                    layout="NCDHW",
                    out_layout="NCDHW",
                )
                lv1: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.zeros_like(lv)
                lv2: R.Tuple(
                    R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
                    R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
                ) = (lv, lv1)
                lv3: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv2[0]
                gv: R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # Example input tensors
    example_args1 = (torch.randn(1, 3, 4, 4, 4, dtype=torch.float32),)
    example_args2 = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),)
    example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)

    # Verify the models with expected IR modules
    verify_model(MaxPool3d(), example_args1, {}, expected1)
    verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
    verify_model(MaxPool3d2(), example_args2, {}, expected2)
    verify_model(MaxPool3d3(), example_args3, {}, expected3)


def test_scaled_dot_product_attention():
    """Test F.scaled_dot_product_attention, without and with an additive mask.

    With run_ep_decomposition=True the op is fully decomposed, so the expected
    IR spells out the math: q and k^T are each scaled by ~0.35355338
    (presumably 64**-0.25, so their product carries the usual 1/sqrt(head_dim)
    factor -- confirm against the decomposition), scores come from a batched
    matmul (broadcast_to + reshape to a rank-3 matmul and back), softmax is
    applied, rows whose scores are all -inf are zeroed out via the
    equal/logical_not/max/where sequence, and a final batched matmul with v
    plus a permute round-trip produces the output.
    """

    class Attention1(Module):
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    @I.ir_module
    class Expected1:
        @R.function
        def main(
            q: R.Tensor((32, 8, 128, 64), dtype="float32"),
            k: R.Tensor((32, 8, 128, 64), dtype="float32"),
            v: R.Tensor((32, 8, 128, 64), dtype="float32"),
        ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
            with R.dataflow():
                # Scale q and k^T separately before the score matmul.
                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
                    q, R.const(0.35355338454246521, "float32")
                )
                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
                    k, axes=[0, 1, 3, 2]
                )
                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
                    lv1, R.const(0.35355338454246521, "float32")
                )
                # Batched matmul q @ k^T, flattened to rank 3 and reshaped back.
                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
                    lv, R.shape([32, 8, 128, 64])
                )
                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
                    lv3, R.shape([256, 128, 64])
                )
                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
                    lv2, R.shape([32, 8, 64, 128])
                )
                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
                    lv5, R.shape([256, 64, 128])
                )
                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
                    lv4, lv6, out_dtype="float32"
                )
                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
                    lv7, R.shape([32, 8, 128, 128])
                )
                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv8, axis=-1)
                # Zero out softmax rows whose scores were entirely -inf.
                lv10: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
                    lv8, R.const(float("-inf"), "float32")
                )
                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv10)
                lv12: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
                    lv11, axis=[-1], keepdims=True
                )
                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv12)
                lv14: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
                    lv9, R.const(0, "int32"), dtype="void"
                )
                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv13, lv14, lv9)
                # Batched matmul attention_weights @ v, then permute round-trip.
                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
                    lv15, R.shape([32, 8, 128, 128])
                )
                lv17: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
                    lv16, R.shape([256, 128, 128])
                )
                lv18: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
                    v, R.shape([32, 8, 128, 64])
                )
                lv19: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
                    lv18, R.shape([256, 128, 64])
                )
                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
                    lv17, lv19, out_dtype="float32"
                )
                lv21: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
                    lv20, R.shape([32, 8, 128, 64])
                )
                lv22: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
                    lv21, axes=[2, 0, 1, 3]
                )
                lv23: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
                    lv22, axes=[1, 2, 0, 3]
                )
                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv23,)
                R.output(gv)
            return gv

    class Attention2(Module):
        def forward(self, q, k, v, mask):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)

    # Same as Expected1 except the mask is added to the scores before softmax.
    @I.ir_module
    class Expected2:
        @R.function
        def main(
            q: R.Tensor((32, 8, 128, 64), dtype="float32"),
            k: R.Tensor((32, 8, 128, 64), dtype="float32"),
            v: R.Tensor((32, 8, 128, 64), dtype="float32"),
            mask: R.Tensor((32, 8, 128, 128), dtype="float32"),
        ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
                    q, R.const(0.35355338454246521, "float32")
                )
                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
                    k, axes=[0, 1, 3, 2]
                )
                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
                    lv1, R.const(0.35355338454246521, "float32")
                )
                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
                    lv, R.shape([32, 8, 128, 64])
                )
                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
                    lv3, R.shape([256, 128, 64])
                )
                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
                    lv2, R.shape([32, 8, 64, 128])
                )
                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
                    lv5, R.shape([256, 64, 128])
                )
                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
                    lv4, lv6, out_dtype="float32"
                )
                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
                    lv7, R.shape([32, 8, 128, 128])
                )
                # Additive attention mask applied to the raw scores.
                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.add(lv8, mask)
                lv10: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv9, axis=-1)
                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
                    lv9, R.const(float("-inf"), "float32")
                )
                lv12: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv11)
                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
                    lv12, axis=[-1], keepdims=True
                )
                lv14: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv13)
                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
                    lv10, R.const(0, "int32"), dtype="void"
                )
                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv14, lv15, lv10)
                lv17: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
                    lv16, R.shape([32, 8, 128, 128])
                )
                lv18: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
                    lv17, R.shape([256, 128, 128])
                )
                lv19: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
                    v, R.shape([32, 8, 128, 64])
                )
                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
                    lv19, R.shape([256, 128, 64])
                )
                lv21: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
                    lv18, lv20, out_dtype="float32"
                )
                lv22: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
                    lv21, R.shape([32, 8, 128, 64])
                )
                lv23: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
                    lv22, axes=[2, 0, 1, 3]
                )
                lv24: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
                    lv23, axes=[1, 2, 0, 3]
                )
                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv24,)
                R.output(gv)
            return gv

    verify_model(
        Attention1(),
        (
            torch.randn(32, 8, 128, 64, dtype=torch.float32),
            torch.randn(32, 8, 128, 64, dtype=torch.float32),
            torch.randn(32, 8, 128, 64, dtype=torch.float32),
        ),
        {},
        Expected1,
        run_ep_decomposition=True,
    )

    verify_model(
        Attention2(),
        (
            torch.randn(32, 8, 128, 64, dtype=torch.float32),
            torch.randn(32, 8, 128, 64, dtype=torch.float32),
            torch.randn(32, 8, 128, 64, dtype=torch.float32),
            torch.randn(32, 8, 128, 128, dtype=torch.float32),
        ),
        {},
        Expected2,
        run_ep_decomposition=True,
    )


def test_unbind():
    """Test lowering of torch.unbind (default dim, dim=1, and size-1 dim).

    Each output element becomes a size-1 R.strided_slice along the unbind
    dimension followed by an R.squeeze of that dimension, and all elements
    are packed into one result tuple.
    """

    class Unbind1(Module):
        def forward(self, data):
            return torch.unbind(data)

    # unbind along the default dim 0 of a (3, 3, 10, 10) tensor -> 3 outputs.
    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            data: R.Tensor((3, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
        ):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(0),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(0),),
                    (R.prim_value(2),),
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[0])
                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[0])
                gv: R.Tuple(
                    R.Tensor((3, 10, 10), dtype="float32"),
                    R.Tensor((3, 10, 10), dtype="float32"),
                    R.Tensor((3, 10, 10), dtype="float32"),
                ) = (lv3, lv4, lv5)
                R.output(gv)
            return gv

    class Unbind2(Module):
        def forward(self, data):
            return torch.unbind(data, dim=1)

    # unbind along dim 1 of a (3, 3, 10, 10) tensor -> 3 outputs.
    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            data: R.Tensor((3, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
        ):
            # block 0
            with R.dataflow():
                lv: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(1),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv2: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[1])
                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[1])
                gv: R.Tuple(
                    R.Tensor((3, 10, 10), dtype="float32"),
                    R.Tensor((3, 10, 10), dtype="float32"),
                ) = (lv3, lv4, lv5)
                R.output(gv)
            return gv

    # unbind along a size-1 dim -> a single-element result tuple.
    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            data: R.Tensor((3, 1, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3, 1, 3), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(1),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv1: R.Tensor((3, 3), dtype="float32") = R.squeeze(lv, axis=[1])
                gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
    verify_model(Unbind1(), example_args, {}, expected1)
    verify_model(Unbind2(), example_args, {}, expected2)
    single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),)
    verify_model(Unbind2(), single_dim_args, {}, expected3)


def test_interpolate():
    """Check ``torch.nn.functional.interpolate`` lowering for bilinear, nearest
    and bicubic modes.

    Bilinear and nearest map directly onto a single ``R.image.resize2d`` call.
    Bicubic is not mapped to ``resize2d``: with ``run_ep_decomposition=True``
    (the ``verify_model`` default) the exported program arrives pre-decomposed
    into elementwise/gather ops, so ``expected_bicubic`` spells out that full
    decomposition.
    """

    # Upsamples (1, 3, 112, 112) -> (1, 3, 224, 224) with bilinear filtering.
    class InterpolateBilinear(Module):
        def forward(self, input):
            return torch.nn.functional.interpolate(input, (224, 224), mode="bilinear")

    # Expected IR: a single resize2d with method="linear" and the PyTorch
    # default coordinate mapping (align_corners=False -> "half_pixel").
    @tvm.script.ir_module
    class expected_bilinear:
        @R.function
        def main(
            input: R.Tensor((1, 3, 112, 112), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d(
                    input,
                    R.shape([224, 224]),
                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
                    layout="NCHW",
                    method="linear",
                    coordinate_transformation_mode="half_pixel",
                    rounding_method="round",
                    cubic_alpha=-0.75,
                    cubic_exclude=0,
                    extrapolation_value=0.0,
                    out_dtype="void",
                )
                gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Same upsampling with nearest-neighbor filtering.
    class InterpolateNearest(Module):
        def forward(self, input):
            return torch.nn.functional.interpolate(input, (224, 224), mode="nearest")

    # Expected IR: identical to the bilinear case except method="nearest_neighbor".
    @tvm.script.ir_module
    class expected_nearest:
        @R.function
        def main(
            input: R.Tensor((1, 3, 112, 112), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d(
                    input,
                    R.shape([224, 224]),
                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
                    layout="NCHW",
                    method="nearest_neighbor",
                    coordinate_transformation_mode="half_pixel",
                    rounding_method="round",
                    cubic_alpha=-0.75,
                    cubic_exclude=0,
                    extrapolation_value=0.0,
                    out_dtype="void",
                )
                gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Same upsampling with bicubic filtering.
    class InterpolateBicubic(Module):
        def forward(self, input):
            return torch.nn.functional.interpolate(input, (224, 224), mode="bicubic")

    # Expected IR for the exported-program decomposition of bicubic
    # interpolation: build output coordinate grids, compute the 4-tap cubic
    # convolution weights per axis (the -0.75 / -3.75 / 1.25 / 2.25 constants
    # below match the standard Keys kernel with a = -0.75), gather the 4x4
    # input neighborhood with clipped `take`s, and blend.
    @tvm.script.ir_module
    class expected_bicubic:
        @R.function
        def main(
            input: R.Tensor((1, 3, 112, 112), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
            with R.dataflow():
                # No-op dtype round-trips emitted by the decomposition.
                lv: R.Tensor((1, 3, 112, 112), dtype="float32") = R.astype(input, dtype="float32")
                lv1: R.Tensor((1, 3, 112, 112), dtype="float32") = R.astype(lv, dtype="float32")
                # Output-coordinate grids (one per spatial axis), 0..223.
                lv2: R.Tensor((224,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(224), R.prim_value(1), dtype="int64"
                )
                lv3: R.Tensor((224,), dtype="float32") = R.astype(lv2, dtype="float32")
                lv4: R.Tensor((224,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(224), R.prim_value(1), dtype="int64"
                )
                lv5: R.Tensor((224,), dtype="float32") = R.astype(lv4, dtype="float32")
                # half_pixel mapping: (i + 0.5) * 0.5 - 0.5, where 0.5 = 112/224.
                lv6: R.Tensor((224,), dtype="float32") = R.add(lv5, R.const(0.5, "float32"))
                lv7: R.Tensor((224,), dtype="float32") = R.multiply(lv6, R.const(0.5, "float32"))
                lv8: R.Tensor((224,), dtype="float32") = R.subtract(lv7, R.const(0.5, "float32"))
                lv9: R.Tensor((224,), dtype="float32") = R.add(lv3, R.const(0.5, "float32"))
                lv10: R.Tensor((224,), dtype="float32") = R.multiply(lv9, R.const(0.5, "float32"))
                lv11: R.Tensor((224,), dtype="float32") = R.subtract(lv10, R.const(0.5, "float32"))
                # Trailing axis added so the height grid broadcasts against width.
                lv12: R.Tensor((224, 1), dtype="float32") = R.expand_dims(lv11, axis=[-1])
                # Integer base indices (floor) and fractional offsets clipped to [0, 1].
                lv13: R.Tensor((224,), dtype="float32") = R.floor(lv8)
                lv14: R.Tensor((224, 1), dtype="float32") = R.floor(lv12)
                lv15: R.Tensor((224, 1), dtype="float32") = R.subtract(lv12, lv14)
                lv16: R.Tensor((224, 1), dtype="float32") = R.clip(
                    lv15, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(1.0))
                )
                lv17: R.Tensor((224,), dtype="float32") = R.subtract(lv8, lv13)
                lv18: R.Tensor((224,), dtype="float32") = R.clip(
                    lv17, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(1.0))
                )
                # The four tap indices per axis: base-1, base, base+1, base+2.
                lv19: R.Tensor((224,), dtype="int64") = R.astype(lv13, dtype="int64")
                lv20: R.Tensor((224, 1), dtype="int64") = R.astype(lv14, dtype="int64")
                lv21: R.Tensor((224, 1), dtype="int64") = R.subtract(lv20, R.const(1, "int64"))
                lv22: R.Tensor((224, 1), dtype="int64") = R.add(lv20, R.const(1, "int64"))
                lv23: R.Tensor((224, 1), dtype="int64") = R.add(lv20, R.const(2, "int64"))
                lv24: R.Tensor((224,), dtype="int64") = R.subtract(lv19, R.const(1, "int64"))
                lv25: R.Tensor((224,), dtype="int64") = R.add(lv19, R.const(1, "int64"))
                lv26: R.Tensor((224,), dtype="int64") = R.add(lv19, R.const(2, "int64"))
                # Cubic convolution weights along the width axis (t and 1-t are
                # stacked so both polynomial branches are evaluated together).
                lv27: R.Tensor((224,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv18)
                lv28: R.Tensor((448,), dtype="float32") = R.concat((lv18, lv27), axis=0)
                lv29: R.Tensor((2, 224), dtype="float32") = R.reshape(lv28, R.shape([2, 224]))
                lv30: R.Tensor((224,), dtype="float32") = R.add(lv18, R.const(1.0, "float32"))
                lv31: R.Tensor((224,), dtype="float32") = R.subtract(R.const(2.0, "float32"), lv18)
                lv32: R.Tensor((448,), dtype="float32") = R.concat((lv30, lv31), axis=0)
                lv33: R.Tensor((2, 224), dtype="float32") = R.reshape(lv32, R.shape([2, 224]))
                lv34: R.Tensor((2, 224), dtype="float32") = R.multiply(
                    lv33, R.const(-0.75, "float32")
                )
                lv35: R.Tensor((2, 224), dtype="float32") = R.subtract(
                    lv34, R.const(-3.75, "float32")
                )
                lv36: R.Tensor((2, 224), dtype="float32") = R.multiply(lv35, lv33)
                lv37: R.Tensor((2, 224), dtype="float32") = R.add(lv36, R.const(-6.0, "float32"))
                lv38: R.Tensor((2, 224), dtype="float32") = R.multiply(lv37, lv33)
                lv39: R.Tensor((2, 224), dtype="float32") = R.subtract(
                    lv38, R.const(-3.0, "float32")
                )
                lv40: R.Tensor((2, 224), dtype="float32") = R.multiply(
                    lv29, R.const(1.25, "float32")
                )
                lv41: R.Tensor((2, 224), dtype="float32") = R.subtract(
                    lv40, R.const(2.25, "float32")
                )
                lv42: R.Tensor((2, 224), dtype="float32") = R.multiply(lv41, lv29)
                lv43: R.Tensor((2, 224), dtype="float32") = R.multiply(lv42, lv29)
                lv44: R.Tensor((2, 224), dtype="float32") = R.add(lv43, R.const(1.0, "float32"))
                # Split the stacked (2, 224) weight tensors back into the four
                # per-tap width weights lv47/lv51/lv52/lv48.
                lv45: R.Tensor((1, 224), dtype="float32") = R.strided_slice(
                    lv39,
                    (R.prim_value(0),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv46: R.Tensor((1, 224), dtype="float32") = R.strided_slice(
                    lv39,
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv47: R.Tensor((224,), dtype="float32") = R.squeeze(lv45, axis=[0])
                lv48: R.Tensor((224,), dtype="float32") = R.squeeze(lv46, axis=[0])
                lv49: R.Tensor((1, 224), dtype="float32") = R.strided_slice(
                    lv44,
                    (R.prim_value(0),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv50: R.Tensor((1, 224), dtype="float32") = R.strided_slice(
                    lv44,
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv51: R.Tensor((224,), dtype="float32") = R.squeeze(lv49, axis=[0])
                lv52: R.Tensor((224,), dtype="float32") = R.squeeze(lv50, axis=[0])
                # Same weight computation along the height axis (extra trailing
                # axis kept for broadcasting against the width weights).
                lv53: R.Tensor((224, 1), dtype="float32") = R.subtract(
                    R.const(1.0, "float32"), lv16
                )
                lv54: R.Tensor((448, 1), dtype="float32") = R.concat((lv16, lv53), axis=0)
                lv55: R.Tensor((2, 224, 1), dtype="float32") = R.reshape(lv54, R.shape([2, 224, 1]))
                lv56: R.Tensor((224, 1), dtype="float32") = R.add(lv16, R.const(1.0, "float32"))
                lv57: R.Tensor((224, 1), dtype="float32") = R.subtract(
                    R.const(2.0, "float32"), lv16
                )
                lv58: R.Tensor((448, 1), dtype="float32") = R.concat((lv56, lv57), axis=0)
                lv59: R.Tensor((2, 224, 1), dtype="float32") = R.reshape(lv58, R.shape([2, 224, 1]))
                lv60: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(
                    lv59, R.const(-0.75, "float32")
                )
                lv61: R.Tensor((2, 224, 1), dtype="float32") = R.subtract(
                    lv60, R.const(-3.75, "float32")
                )
                lv62: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv61, lv59)
                lv63: R.Tensor((2, 224, 1), dtype="float32") = R.add(lv62, R.const(-6.0, "float32"))
                lv64: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv63, lv59)
                lv65: R.Tensor((2, 224, 1), dtype="float32") = R.subtract(
                    lv64, R.const(-3.0, "float32")
                )
                lv66: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(
                    lv55, R.const(1.25, "float32")
                )
                lv67: R.Tensor((2, 224, 1), dtype="float32") = R.subtract(
                    lv66, R.const(2.25, "float32")
                )
                lv68: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv67, lv55)
                lv69: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv68, lv55)
                lv70: R.Tensor((2, 224, 1), dtype="float32") = R.add(lv69, R.const(1.0, "float32"))
                # Per-tap height weights lv73/lv77/lv78/lv74.
                lv71: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice(
                    lv65,
                    (R.prim_value(0),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv72: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice(
                    lv65,
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv73: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv71, axis=[0])
                lv74: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv72, axis=[0])
                lv75: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice(
                    lv70,
                    (R.prim_value(0),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv76: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice(
                    lv70,
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv77: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv75, axis=[0])
                lv78: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv76, axis=[0])
                # Gather the 4x4 neighborhood: each (clip to [0, 111], take
                # along width, take along height) pair fetches one tap; the
                # four width taps of a row are then blended with lv47/51/52/48.
                lv79: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv21, R.prim_value(0), R.prim_value(111)
                )
                lv80: R.Tensor((224,), dtype="int64") = R.clip(
                    lv24, R.prim_value(0), R.prim_value(111)
                )
                lv81: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv80, axis=3, mode="fast"
                )
                lv82: R.Tensor((224,), dtype="int64") = R.squeeze(lv79, axis=None)
                lv83: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv81, lv82, axis=2, mode="fast"
                )
                lv84: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv21, R.prim_value(0), R.prim_value(111)
                )
                lv85: R.Tensor((224,), dtype="int64") = R.clip(
                    lv19, R.prim_value(0), R.prim_value(111)
                )
                lv86: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv85, axis=3, mode="fast"
                )
                lv87: R.Tensor((224,), dtype="int64") = R.squeeze(lv84, axis=None)
                lv88: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv86, lv87, axis=2, mode="fast"
                )
                lv89: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv21, R.prim_value(0), R.prim_value(111)
                )
                lv90: R.Tensor((224,), dtype="int64") = R.clip(
                    lv25, R.prim_value(0), R.prim_value(111)
                )
                lv91: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv90, axis=3, mode="fast"
                )
                lv92: R.Tensor((224,), dtype="int64") = R.squeeze(lv89, axis=None)
                lv93: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv91, lv92, axis=2, mode="fast"
                )
                lv94: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv21, R.prim_value(0), R.prim_value(111)
                )
                lv95: R.Tensor((224,), dtype="int64") = R.clip(
                    lv26, R.prim_value(0), R.prim_value(111)
                )
                lv96: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv95, axis=3, mode="fast"
                )
                lv97: R.Tensor((224,), dtype="int64") = R.squeeze(lv94, axis=None)
                lv98: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv96, lv97, axis=2, mode="fast"
                )
                # Row 0 (height tap base-1) blended across its four width taps.
                lv99: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv83, lv47)
                lv100: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv88, lv51)
                lv101: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv99, lv100)
                lv102: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv93, lv52)
                lv103: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv101, lv102)
                lv104: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv98, lv48)
                lv105: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv103, lv104)
                # Row 1 (height tap base).
                lv106: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv20, R.prim_value(0), R.prim_value(111)
                )
                lv107: R.Tensor((224,), dtype="int64") = R.clip(
                    lv24, R.prim_value(0), R.prim_value(111)
                )
                lv108: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv107, axis=3, mode="fast"
                )
                lv109: R.Tensor((224,), dtype="int64") = R.squeeze(lv106, axis=None)
                lv110: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv108, lv109, axis=2, mode="fast"
                )
                lv111: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv20, R.prim_value(0), R.prim_value(111)
                )
                lv112: R.Tensor((224,), dtype="int64") = R.clip(
                    lv19, R.prim_value(0), R.prim_value(111)
                )
                lv113: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv112, axis=3, mode="fast"
                )
                lv114: R.Tensor((224,), dtype="int64") = R.squeeze(lv111, axis=None)
                lv115: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv113, lv114, axis=2, mode="fast"
                )
                lv116: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv20, R.prim_value(0), R.prim_value(111)
                )
                lv117: R.Tensor((224,), dtype="int64") = R.clip(
                    lv25, R.prim_value(0), R.prim_value(111)
                )
                lv118: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv117, axis=3, mode="fast"
                )
                lv119: R.Tensor((224,), dtype="int64") = R.squeeze(lv116, axis=None)
                lv120: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv118, lv119, axis=2, mode="fast"
                )
                lv121: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv20, R.prim_value(0), R.prim_value(111)
                )
                lv122: R.Tensor((224,), dtype="int64") = R.clip(
                    lv26, R.prim_value(0), R.prim_value(111)
                )
                lv123: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv122, axis=3, mode="fast"
                )
                lv124: R.Tensor((224,), dtype="int64") = R.squeeze(lv121, axis=None)
                lv125: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv123, lv124, axis=2, mode="fast"
                )
                lv126: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv110, lv47)
                lv127: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv115, lv51)
                lv128: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv126, lv127)
                lv129: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv120, lv52)
                lv130: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv128, lv129)
                lv131: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv125, lv48)
                lv132: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv130, lv131)
                # Row 2 (height tap base+1).
                lv133: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv22, R.prim_value(0), R.prim_value(111)
                )
                lv134: R.Tensor((224,), dtype="int64") = R.clip(
                    lv24, R.prim_value(0), R.prim_value(111)
                )
                lv135: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv134, axis=3, mode="fast"
                )
                lv136: R.Tensor((224,), dtype="int64") = R.squeeze(lv133, axis=None)
                lv137: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv135, lv136, axis=2, mode="fast"
                )
                lv138: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv22, R.prim_value(0), R.prim_value(111)
                )
                lv139: R.Tensor((224,), dtype="int64") = R.clip(
                    lv19, R.prim_value(0), R.prim_value(111)
                )
                lv140: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv139, axis=3, mode="fast"
                )
                lv141: R.Tensor((224,), dtype="int64") = R.squeeze(lv138, axis=None)
                lv142: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv140, lv141, axis=2, mode="fast"
                )
                lv143: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv22, R.prim_value(0), R.prim_value(111)
                )
                lv144: R.Tensor((224,), dtype="int64") = R.clip(
                    lv25, R.prim_value(0), R.prim_value(111)
                )
                lv145: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv144, axis=3, mode="fast"
                )
                lv146: R.Tensor((224,), dtype="int64") = R.squeeze(lv143, axis=None)
                lv147: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv145, lv146, axis=2, mode="fast"
                )
                lv148: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv22, R.prim_value(0), R.prim_value(111)
                )
                lv149: R.Tensor((224,), dtype="int64") = R.clip(
                    lv26, R.prim_value(0), R.prim_value(111)
                )
                lv150: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv149, axis=3, mode="fast"
                )
                lv151: R.Tensor((224,), dtype="int64") = R.squeeze(lv148, axis=None)
                lv152: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv150, lv151, axis=2, mode="fast"
                )
                lv153: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv137, lv47)
                lv154: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv142, lv51)
                lv155: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv153, lv154)
                lv156: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv147, lv52)
                lv157: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv155, lv156)
                lv158: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv152, lv48)
                lv159: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv157, lv158)
                # Row 3 (height tap base+2).
                lv160: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv23, R.prim_value(0), R.prim_value(111)
                )
                lv161: R.Tensor((224,), dtype="int64") = R.clip(
                    lv24, R.prim_value(0), R.prim_value(111)
                )
                lv162: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv161, axis=3, mode="fast"
                )
                lv163: R.Tensor((224,), dtype="int64") = R.squeeze(lv160, axis=None)
                lv164: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv162, lv163, axis=2, mode="fast"
                )
                lv165: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv23, R.prim_value(0), R.prim_value(111)
                )
                lv166: R.Tensor((224,), dtype="int64") = R.clip(
                    lv19, R.prim_value(0), R.prim_value(111)
                )
                lv167: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv166, axis=3, mode="fast"
                )
                lv168: R.Tensor((224,), dtype="int64") = R.squeeze(lv165, axis=None)
                lv169: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv167, lv168, axis=2, mode="fast"
                )
                lv170: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv23, R.prim_value(0), R.prim_value(111)
                )
                lv171: R.Tensor((224,), dtype="int64") = R.clip(
                    lv25, R.prim_value(0), R.prim_value(111)
                )
                lv172: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv171, axis=3, mode="fast"
                )
                lv173: R.Tensor((224,), dtype="int64") = R.squeeze(lv170, axis=None)
                lv174: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv172, lv173, axis=2, mode="fast"
                )
                lv175: R.Tensor((224, 1), dtype="int64") = R.clip(
                    lv23, R.prim_value(0), R.prim_value(111)
                )
                lv176: R.Tensor((224,), dtype="int64") = R.clip(
                    lv26, R.prim_value(0), R.prim_value(111)
                )
                lv177: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take(
                    lv1, lv176, axis=3, mode="fast"
                )
                lv178: R.Tensor((224,), dtype="int64") = R.squeeze(lv175, axis=None)
                lv179: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take(
                    lv177, lv178, axis=2, mode="fast"
                )
                lv180: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv164, lv47)
                lv181: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv169, lv51)
                lv182: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv180, lv181)
                lv183: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv174, lv52)
                lv184: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv182, lv183)
                lv185: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv179, lv48)
                lv186: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv184, lv185)
                # Final blend of the four row results with the height weights.
                lv187: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv105, lv73)
                lv188: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv132, lv77)
                lv189: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv187, lv188)
                lv190: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv159, lv78)
                lv191: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv189, lv190)
                lv192: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv186, lv74)
                lv193: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv191, lv192)
                # Trailing no-op dtype round-trips, mirroring the leading ones.
                lv194: R.Tensor((1, 3, 224, 224), dtype="float32") = R.astype(
                    lv193, dtype="float32"
                )
                lv195: R.Tensor((1, 3, 224, 224), dtype="float32") = R.astype(
                    lv194, dtype="float32"
                )
                gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv195,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
    verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
    verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
    verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)


def test_mean():
    """Check ``Tensor.mean(dim)`` lowering to ``R.mean``, with and without
    ``keepdim`` (which maps onto the ``keepdims`` attribute and changes the
    result shape from (256,) to (256, 1))."""

    # Reduce over the last axis, dropping it.
    class Mean(Module):
        def forward(self, input):
            return input.mean(-1)

    # Reduce over the last axis, keeping it as a size-1 dim.
    class MeanKeepDim(Module):
        def forward(self, input: torch.Tensor):
            return input.mean(-1, keepdim=True)

    @I.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False)
                gv: R.Tuple(R.Tensor((256,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @I.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True)
                gv: R.Tuple(R.Tensor((256, 1), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(256, 256, dtype=torch.float32),)
    verify_model(Mean(), example_args, {}, Expected1)
    verify_model(MeanKeepDim(), example_args, {}, Expected2)


def test_sum():
    """Check ``torch.sum`` over multiple dims lowering to ``R.sum``.

    The dim tuple (2, 1) is forwarded to ``axis=[2, 1]`` in the original
    order, reducing (1, 2, 3, 4) to (1, 4)."""

    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False)
                gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(Sum(), example_args, {}, expected1)


def test_argmax_argmin():
    """Check ``torch.argmax``/``torch.argmin`` lowering to ``R.argmax``/``R.argmin``.

    Covers both an explicit ``dim`` (argmax, axis=-1) and the no-``dim`` global
    reduction (argmin, axis=None), each with ``keepdim`` False and True. Note
    the global reduction with ``keepdim=True`` yields shape (1, 1)."""

    example_args = (torch.randn(256, 256, dtype=torch.float32),)

    # argmax over the last axis, dropping it.
    class Argmax1(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmax(input, dim=-1)

    # argmax over the last axis, keeping it as a size-1 dim.
    class Argmax2(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmax(input, dim=-1, keepdim=True)

    @tvm.script.ir_module
    class expected_argmax1:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256,), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False)
                gv: R.Tuple(R.Tensor((256,), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_argmax2:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True)
                gv: R.Tuple(R.Tensor((256, 1), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Argmax1(), example_args, {}, expected_argmax1)
    verify_model(Argmax2(), example_args, {}, expected_argmax2)

    # Global argmin (no dim) -> scalar result.
    class Argmin1(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmin(input)

    # Global argmin with keepdim -> all-ones shape (1, 1).
    class Argmin2(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmin(input, keepdim=True)

    @tvm.script.ir_module
    class expected_argmin1:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_argmin2:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True)
                gv: R.Tuple(R.Tensor((1, 1), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Argmin1(), example_args, {}, expected_argmin1)
    verify_model(Argmin2(), example_args, {}, expected_argmin2)


def test_cat_concat():
    """torch.cat / torch.concat should lower to R.concat with the matching axis."""

    class Cat0(Module):
        # No dim argument: torch.cat defaults to dim=0.
        def forward(self, x, y):
            return torch.cat((x, y))

    class Cat1(Module):
        # Keyword dim.
        def forward(self, x, y):
            return torch.cat((x, y), dim=1)

    class Cat2(Module):
        # Positional dim; must lower the same as Cat1.
        def forward(self, x, y):
            return torch.cat((x, y), 1)

    class Cat3(Module):
        # torch.concat alias; must lower the same as Cat0.
        def forward(self, x, y):
            return torch.concat((x, y), dim=0)

    # Expected IR for concatenation along axis 0: (2,3)+(2,3) -> (4,3).
    @I.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, inp_1), axis=0)
                gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Expected IR for concatenation along axis 1: (2,3)+(2,3) -> (2,6).
    @I.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 6), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, inp_1), axis=1)
                gv: R.Tuple(R.Tensor((2, 6), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
    verify_model(Cat0(), example_args, {}, Expected1)
    verify_model(Cat1(), example_args, {}, Expected2)
    verify_model(Cat2(), example_args, {}, Expected2)
    verify_model(Cat3(), example_args, {}, Expected1)


def test_cumsum():
    """torch.cumsum with an explicit dtype should lower to R.cumsum carrying that dtype."""

    class Cumsum(Module):
        def forward(self, input):
            return torch.cumsum(input, dim=1, dtype=torch.int32)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")):
            # block 0
            with R.dataflow():
                # The float32 input is summed along axis 1 and cast to int32 in one op.
                lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32")
                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(Cumsum(), example_args, {}, expected1)


def test_expand():
    """Tensor.expand should lower to R.broadcast_to, resolving -1 dims to the input shape."""

    class Expand1(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    class Expand2(Module):
        # -1 keeps the input's existing size for that dim; for a (1,2,3,4) input
        # this produces the same target shape as Expand1, hence the shared expected IR.
        def forward(self, x):
            return x.expand(4, -1, -1, 4)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4))
                gv: R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(Expand1(), example_args, {}, expected1)
    verify_model(Expand2(), example_args, {}, expected1)


def test_flatten():
    """nn.Flatten(start_dim, end_dim) should lower to a static R.reshape."""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            # Flatten dims 2..end: (1, 3, 10, 10) -> (1, 3, 100).
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, input):
            return self.f(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100))
                gv: R.Tuple(R.Tensor((1, 3, 100), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Flatten(), example_args, {}, expected1)


def test_meshgrid():
    """torch.meshgrid should lower to reshape + broadcast_to pairs for both indexing modes."""

    class Meshgrid1(Module):
        # Matrix ("ij") indexing: first input varies along rows.
        def forward(self, input1, input2):
            return torch.meshgrid((input1, input2), indexing="ij")

    class Meshgrid2(Module):
        # Cartesian ("xy") indexing: inputs/outputs are swapped relative to "ij".
        def forward(self, input1, input2):
            return torch.meshgrid((input1, input2), indexing="xy")

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
            with R.dataflow():
                # input1 becomes a column vector broadcast across columns ...
                lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input1, R.shape([3, 1]))
                lv1: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv, R.shape([3, 3]))
                # ... and input2 a row vector broadcast across rows.
                lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input2, R.shape([1, 3]))
                lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2, R.shape([3, 3]))
                gv: R.Tuple(
                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
                ) = (lv1, lv3)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")):
            with R.dataflow():
                # "xy" mode: the roles of input1 and input2 are exchanged
                # and the result tuple order is reversed (lv3, lv1).
                lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input2, R.shape([3, 1]))
                lv1: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv, R.shape([3, 3]))
                lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input1, R.shape([1, 3]))
                lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2, R.shape([3, 3]))
                gv: R.Tuple(
                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")
                ) = (lv3, lv1)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(3, dtype=torch.float32),
        torch.randn(3, dtype=torch.float32),
    )
    verify_model(Meshgrid1(), example_args, {}, expected1)
    verify_model(Meshgrid2(), example_args, {}, expected2)


def test_permute():
    """Tensor.permute and torch.permute should both lower to R.permute_dims."""

    class Permute1(Module):
        # Method form.
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    class Permute2(Module):
        # Functional form; must produce identical IR.
        def forward(self, x):
            return torch.permute(x, (0, 3, 2, 1))

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
                gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(Permute1(), example_args, {}, expected1)
    verify_model(Permute2(), example_args, {}, expected1)


def test_repeat():
    """Tensor.repeat should lower to R.tile with the matching repeat counts."""

    class Tile1(Module):
        # 1-D repeat: (3,) tiled twice -> (6,).
        def forward(self, x: torch.Tensor):
            return x.repeat(2)

    class Tile2(Module):
        # 2-D repeat: (1, 3) tiled (4, 2) -> (4, 6).
        def forward(self, x: torch.Tensor):
            return x.repeat(4, 2)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
                gv: R.Tuple(R.Tensor((6,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            x: R.Tensor((1, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
                gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(3, dtype=torch.float32),)
    verify_model(Tile1(), example_args, {}, expected1)

    # NOTE: the original test verified Tile2 twice with identical args;
    # the redundant duplicate call has been removed.
    example_args = (torch.randn(1, 3, dtype=torch.float32),)
    verify_model(Tile2(), example_args, {}, expected2)


def test_reshape():
    """Tensor.reshape with static target dims should lower to R.reshape."""

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(Reshape(), example_args, {}, expected1)


def test_reshape_as():
    """Tensor.reshape_as(y) should lower to R.reshape using y's static shape."""

    class ReshapeAs(Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            return x.reshape_as(y)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32"),
            y: R.Tensor((2, 12), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
            # block 0
            with R.dataflow():
                # The target shape (2, 12) is baked in from y's static shape;
                # y itself is unused in the lowered IR.
                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(1, 2, 3, 4, dtype=torch.float32),
        torch.randn(2, 12, dtype=torch.float32),
    )
    verify_model(ReshapeAs(), example_args, {}, expected1)


def test_roll():
    """torch.roll should lower to an arange/add/mod/take gather sequence.

    A roll by ``s`` along an axis of size ``n`` is expressed as
    ``take(x, (arange(n) + (n - s) % n) % n, axis)``; rolling with no dims
    flattens first and reshapes back afterwards.
    """

    class Roll1(Module):
        # No dims: flatten, roll the flat tensor by 1, reshape back.
        def forward(self, x):
            return torch.roll(x, 1)

    class Roll2(Module):
        # Negative shift along dim 0.
        def forward(self, x):
            return torch.roll(x, -1, 0)

    class Roll3(Module):
        # Multiple shifts/dims: lowered as two sequential single-axis rolls.
        def forward(self, x):
            return torch.roll(x, shifts=(2, 1), dims=(0, 1))

    # Test case 1: torch.roll(x, 1)
    @I.ir_module
    class Expected1:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                # Flatten (4, 2) -> (8,); shift 1 on 8 elements becomes +7 mod 8.
                lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
                lv1: R.Tensor((8,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(8), R.prim_value(1), dtype="int64"
                )
                lv2: R.Tensor((8,), dtype="int64") = R.add(lv1, R.const(7, "int64"))
                lv3: R.Tensor((8,), dtype="int64") = R.mod(lv2, R.const(8, "int64"))
                lv4: R.Tensor((8,), dtype="int64") = R.take(lv, lv3, axis=0, mode="fast")
                lv5: R.Tensor((4, 2), dtype="int64") = R.reshape(lv4, R.shape([4, 2]))
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
                R.output(gv)
            return gv

    # Test case 2: torch.roll(x, -1, 0)
    @I.ir_module
    class Expected2:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                # Shift -1 on an axis of size 4 becomes +1 mod 4.
                lv: R.Tensor((4,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(4), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(1, "int64"))
                lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, "int64"))
                lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, mode="fast")
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv3,)
                R.output(gv)
            return gv

    # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1))
    @I.ir_module
    class Expected3:
        @R.function
        def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")):
            with R.dataflow():
                # First roll along dim=0 with shift=2
                lv: R.Tensor((4,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(4), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(2, "int64"))
                lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, "int64"))
                lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, mode="fast")
                # Second roll along dim=1 with shift=1
                lv4: R.Tensor((2,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64"
                )
                lv5: R.Tensor((2,), dtype="int64") = R.add(lv4, R.const(1, "int64"))
                lv6: R.Tensor((2,), dtype="int64") = R.mod(lv5, R.const(2, "int64"))
                lv7: R.Tensor((4, 2), dtype="int64") = R.take(lv3, lv6, axis=1, mode="fast")
                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv7,)
                R.output(gv)
            return gv

    # Test inputs
    example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)

    # Run verification for each case
    verify_model(Roll1(), (example_input,), {}, Expected1)
    verify_model(Roll2(), (example_input,), {}, Expected2)
    verify_model(Roll3(), (example_input,), {}, Expected3)


def test_select_slice():
    """Advanced indexing should lower to R.take / R.strided_slice / R.expand_dims chains.

    Integer indices become R.take (removing the axis), slice components become
    one R.strided_slice per axis, and None inserts a new axis via R.expand_dims.
    """

    class Slice1(Module):
        # Mixed select + slices: x[0, 1::2, :, :3].
        def forward(self, x):
            return x[0, 1::2, :, :3]

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")):
            # block 0
            with R.dataflow():
                # x[0] drops the leading axis.
                lv: R.Tensor((3, 10, 10), dtype="float32") = R.take(x, R.const(0, "int64"), axis=0)
                # 1::2 on the (now leading) axis; end is INT64_MAX for an open slice.
                lv1: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
                    lv,
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(9223372036854775807),),
                    (R.prim_value(2),),
                    assume_inbound=False,
                )
                lv2: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
                    lv1,
                    (R.prim_value(1),),
                    (R.prim_value(0),),
                    (R.prim_value(9223372036854775807),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv3: R.Tensor((1, 10, 3), dtype="float32") = R.strided_slice(
                    lv2,
                    (R.prim_value(2),),
                    (R.prim_value(0),),
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                gv: R.Tuple(R.Tensor((1, 10, 3), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    class Slice2(Module):
        # None indices insert singleton axes: (8, 16) -> (8, 1, 1, 16, 1).
        def forward(self, x):
            return x[:, None, None, :, None]

    @I.ir_module
    class expected2:
        @R.function
        def main(
            x: R.Tensor((8, 16), dtype="float32")
        ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice(
                    x,
                    (R.prim_value(0),),
                    (R.prim_value(0),),
                    (R.prim_value(9223372036854775807),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv1: R.Tensor((8, 1, 16), dtype="float32") = R.expand_dims(lv, axis=[1])
                lv2: R.Tensor((8, 1, 1, 16), dtype="float32") = R.expand_dims(lv1, axis=[2])
                lv3: R.Tensor((8, 1, 1, 16), dtype="float32") = R.strided_slice(
                    lv2,
                    (R.prim_value(3),),
                    (R.prim_value(0),),
                    (R.prim_value(9223372036854775807),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv4: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.expand_dims(lv3, axis=[4])
                gv: R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")) = (lv4,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Slice1(), example_args, {}, expected1)

    example_args = (torch.randn(8, 16, dtype=torch.float32),)
    verify_model(Slice2(), example_args, {}, expected2)


def test_slice_scatter():
    """torch.slice_scatter should lower to R.slice_scatter with normalized bounds."""

    class SliceScatter1(Module):
        # Strided scatter along dim 1.
        def forward(self, input, src):
            return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            a: R.Tensor((8, 8, 10, 10), dtype="float32"),
            b: R.Tensor((8, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter(
                    a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1
                )
                gv: R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    class SliceScatter2(Module):
        # Contiguous scatter along dim 0.
        def forward(self, input, src):
            return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1)

    @I.ir_module
    class expected2:
        @R.function
        def main(
            a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32")
        ) -> R.Tuple(R.Tensor((8, 16), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter(
                    a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0
                )
                gv: R.Tuple(R.Tensor((8, 16), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    class SliceScatterNegative(Module):
        # Negative end: -2 on a size-5 axis normalizes to end=3 in the IR below.
        def forward(self, input, src):
            return torch.slice_scatter(input, src, dim=1, start=0, end=-2, step=1)

    @tvm.script.ir_module
    class expected_slice_scatter:
        @R.function
        def main(
            a: R.Tensor((2, 5), dtype="float32"), b: R.Tensor((2, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((2, 5), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 5), dtype="float32") = R.slice_scatter(
                    a, b, R.prim_value(0), R.prim_value(3), R.prim_value(1), axis=1
                )
                gv: R.Tuple(R.Tensor((2, 5), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), torch.randn(8, 3, 10, 10))
    verify_model(SliceScatter1(), example_args, {}, expected1)

    example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 16))
    verify_model(SliceScatter2(), example_args, {}, expected2)

    example_args = (torch.randn(2, 5, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
    verify_model(SliceScatterNegative(), example_args, {}, expected_slice_scatter)


def test_split():
    """torch.chunk should lower to R.split; torch.unbind to per-index strided_slice + squeeze."""

    class Chunk(Module):
        # 3 equal chunks along dim 1 of a size-3 axis.
        def forward(self, input):
            return torch.chunk(input, 3, dim=1)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 1, 10, 10), dtype="float32"),
            R.Tensor((1, 1, 10, 10), dtype="float32"),
            R.Tensor((1, 1, 10, 10), dtype="float32"),
        ):
            # block 0
            with R.dataflow():
                # chunk(3) on axis of size 3 -> split at indices [1, 2].
                lv: R.Tuple(
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                ) = R.split(input, indices_or_sections=[1, 2], axis=1)
                lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0]
                lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1]
                lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2]
                gv: R.Tuple(
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                ) = (lv1, lv2, lv3)
                R.output(gv)
            return gv

    class Unbind1(Module):
        # Default dim=0.
        def forward(self, data):
            return torch.unbind(data)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            data: R.Tensor((3, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
        ):
            # block 0
            with R.dataflow():
                # One unit-width strided_slice per index along axis 0 ...
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(0),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(0),),
                    (R.prim_value(2),),
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                # ... then the sliced axis is squeezed away.
                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[0])
                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[0])
                gv: R.Tuple(
                    R.Tensor((3, 10, 10), dtype="float32"),
                    R.Tensor((3, 10, 10), dtype="float32"),
                    R.Tensor((3, 10, 10), dtype="float32"),
                ) = (lv3, lv4, lv5)
                R.output(gv)
            return gv

    class Unbind2(Module):
        # Explicit dim=1.
        def forward(self, data):
            return torch.unbind(data, dim=1)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            data: R.Tensor((3, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
            R.Tensor((3, 10, 10), dtype="float32"),
        ):
            # block 0
            with R.dataflow():
                lv: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(1),),
                    (R.prim_value(0),),
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(1),),
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv2: R.Tensor((3, 1, 10, 10), dtype="float32") = R.strided_slice(
                    data,
                    (R.prim_value(1),),
                    (R.prim_value(2),),
                    (R.prim_value(3),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, axis=[1])
                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, axis=[1])
                gv: R.Tuple(
                    R.Tensor((3, 10, 10), dtype="float32"),
                    R.Tensor((3, 10, 10), dtype="float32"),
                    R.Tensor((3, 10, 10), dtype="float32"),
                ) = (lv3, lv4, lv5)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Chunk(), example_args, {}, Expected)

    example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
    verify_model(Unbind1(), example_args, {}, expected1)
    verify_model(Unbind2(), example_args, {}, expected2)


def test_squeeze():
    """Tensor.squeeze should lower to R.squeeze; no-dim form squeezes all size-1 axes."""

    class Squeeze1(Module):
        # Squeeze only dim 1.
        def forward(self, input):
            return input.squeeze(1)

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1])
                gv: R.Tuple(R.Tensor((3, 4, 1), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    class Squeeze2(Module):
        # No dim argument: all singleton axes (1 and 3 here) are squeezed.
        def forward(self, input):
            return input.squeeze()

    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            input: R.Tensor((3, 1, 4, 1), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[1, 3])
                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)

    verify_model(Squeeze1(), example_args, {}, Expected1)
    verify_model(Squeeze2(), example_args, {}, Expected2)


def test_stack():
    """torch.stack should lower to concat+reshape (dims 0/1) or expand_dims+concat (dim=-1)."""

    class Stack0(Module):
        def forward(self, x, y):
            return torch.stack((x, y))  # default dim=0

    class Stack1(Module):
        def forward(self, x, y):
            return torch.stack((x, y), dim=1)

    class Stack2(Module):
        def forward(self, x, y):
            return torch.stack((x, y), 1)  # positional dim

    class Stack3(Module):
        def forward(self, x, y):
            return torch.stack((x, y), dim=-1)  # negative dim

    @I.ir_module
    class Expected0:
        @R.function
        def main(
            x: R.Tensor((2, 3), dtype="float32"),
            y: R.Tensor((2, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
            with R.dataflow():
                # Stack at dim 0 via concat along axis 0 then reshape to insert the new axis.
                lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0)
                lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    @I.ir_module
    class Expected1:
        @R.function
        def main(
            x: R.Tensor((2, 3), dtype="float32"),
            y: R.Tensor((2, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
            with R.dataflow():
                # Stack at dim 1 via concat along axis 1 then reshape.
                lv: R.Tensor((2, 6), dtype="float32") = R.concat((x, y), axis=1)
                lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
                gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    @I.ir_module
    class Expected3:
        @R.function
        def main(
            x: R.Tensor((2, 3), dtype="float32"),
            y: R.Tensor((2, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")):
            with R.dataflow():
                # Negative dim: each input gets a trailing axis, then concat on it.
                lv: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(x, axis=[2])
                lv1: R.Tensor((2, 3, 1), dtype="float32") = R.expand_dims(y, axis=[2])
                lv2: R.Tensor((2, 3, 2), dtype="float32") = R.concat((lv, lv1), axis=-1)
                gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv2,)
                R.output(gv)
            return gv

    example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))

    verify_model(Stack0(), example_args, {}, Expected0)
    verify_model(Stack1(), example_args, {}, Expected1)
    verify_model(Stack2(), example_args, {}, Expected1)
    verify_model(Stack3(), example_args, {}, Expected3)


def test_tile():
    """Tensor.tile / torch.tile should lower to R.tile, left-padding repeats to the rank."""

    class Tile1(Module):
        # Repeats shorter than rank: (2,) is padded to [1, 2] for a 2-D input.
        def forward(self, x):
            return x.tile((2,))

    class Tile2(Module):
        # Varargs form.
        def forward(self, x):
            return x.tile(4, 2)

    class Tile3(Module):
        # Functional form; must match Tile2's IR.
        def forward(self, x):
            return torch.tile(x, (4, 2))

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, repeats=[1, 2])
                gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            x: R.Tensor((1, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, repeats=[4, 2])
                gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, dtype=torch.float32),)
    verify_model(Tile1(), example_args, {}, expected1)
    verify_model(Tile2(), example_args, {}, expected2)
    verify_model(Tile3(), example_args, {}, expected2)


def test_transpose():
    """Tensor.transpose(d0, d1) should lower to R.permute_dims with d0/d1 swapped."""

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
            # block 0
            with R.dataflow():
                # transpose(1, 3) is a full permutation with axes 1 and 3 exchanged.
                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
                gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(Transpose(), example_args, {}, expected1)


def test_unsqueeze():
    """Tensor.unsqueeze should lower to R.expand_dims for both positive and negative dims."""

    class Unsqueeze1(Module):
        def forward(self, input):
            return input.unsqueeze(1)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1)
                gv: R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Negative dim (-1) inserts the new axis at the end.
    class Unsqueeze2(Module):
        def forward(self, input):
            return input.unsqueeze(-1)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1)
                gv: R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)

    verify_model(Unsqueeze1(), example_args, {}, expected1)
    verify_model(Unsqueeze2(), example_args, {}, expected2)


def test_view():
    """Tensor.view should lower to R.reshape with the requested target shape."""

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(View(), example_args, {}, expected1)


def test_as_strided():
    """aten.as_strided lowers to R.reshape for contiguous views; the importer
    must reject non-contiguous strides and non-zero storage offsets."""

    class AsStrided(Module):
        def forward(self, x):
            return torch.ops.aten.as_strided.default(x, (3, 2, 2), (4, 2, 1))

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 2, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 2, 2), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3, 2, 2), dtype="float32") = R.reshape(x, (3, 2, 2))
                gv: R.Tuple(R.Tensor((3, 2, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # Strides (6, 3, 1) do not describe a contiguous view of a (2, 2, 3) input.
    class AsStridedNonContiguous(Module):
        def forward(self, x):
            return torch.ops.aten.as_strided.default(x, (2, 2, 2), (6, 3, 1))

    # storage_offset != 0 is unsupported by the importer.
    class AsStridedWithStorageOffset(Module):
        def forward(self, x):
            return torch.ops.aten.as_strided.default(x, (2, 2), (2, 1), 1)

    example_args = (torch.randn(2, 2, 3, dtype=torch.float32),)
    verify_model(AsStrided(), example_args, {}, Expected)

    exported = export(AsStridedNonContiguous(), args=example_args)
    with pytest.raises(AssertionError, match="non-contiguous stride"):
        from_exported_program(exported)

    example_args = (torch.randn(2, 2, dtype=torch.float32),)
    exported = export(AsStridedWithStorageOffset(), args=example_args)
    with pytest.raises(AssertionError, match="storage_offset"):
        from_exported_program(exported)


def test_arange():
    """torch.arange with an explicit dtype should lower to R.arange with step 1."""

    class Arange(Module):
        def forward(self, input):
            return torch.arange(0, 20, dtype=torch.int32)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((20,), dtype="int32")):
            with R.dataflow():
                lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32")
                gv: R.Tuple(R.Tensor((20,), dtype="int32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(Arange(), example_args, {}, Expected)


def test_hamming_window():
    """torch.hamming_window should lower to R.hamming_window with the default
    alpha/beta coefficients (0.54 / 0.46) made explicit."""

    class HammingWindow(Module):
        def forward(self, input):
            return torch.hamming_window(20, True, dtype=torch.float32)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((20,), dtype="float32")):
            with R.dataflow():
                # periodic=True is encoded as prim_value(1).
                lv: R.Tensor((20,), dtype="float32") = R.hamming_window(
                    R.prim_value(20),
                    R.prim_value(1),
                    R.prim_value(T.float32(0.54000000000000004)),
                    R.prim_value(T.float32(0.46000000000000002)),
                    dtype="float32",
                )
                gv: R.Tuple(R.Tensor((20,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(HammingWindow(), example_args, {}, Expected)


def test_contiguous():
    """Tensor.contiguous is a no-op in Relax: the input is returned unchanged."""

    class Contiguous(Module):
        def forward(self, input):
            return input.contiguous()

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,)
                R.output(gv)
            return gv

    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(Contiguous(), example_args, {}, Expected)


def test_clone():
    """torch.clone is a no-op in Relax: the input is returned unchanged."""

    class Clone(Module):
        def forward(self, input):
            return torch.clone(input)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,)
                R.output(gv)
            return gv

    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(Clone(), example_args, {}, Expected)


def test_empty():
    """torch.empty with an explicit dtype lowers to R.zeros (contents of
    torch.empty are unspecified, so zeros is a valid materialization)."""

    class Empty(Module):
        def forward(self, input):
            return torch.empty((10, 10), dtype=torch.float32)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.zeros(
                    R.shape([10, 10]), dtype="float32"
                )
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(Empty(), example_args, {}, Expected)


def test_empty_without_dtype():
    """torch.empty without a dtype lowers to R.zeros with float32 as the default."""

    class EmptyWithoutDtype(Module):
        def forward(self, input):
            return torch.empty((5, 5))

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5, 5), dtype="float32") = R.zeros(R.shape([5, 5]), dtype="float32")
                gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(EmptyWithoutDtype(), example_args, {}, Expected)


def test_fill():
    """torch.fill should lower to R.full_like with the fill value as a constant."""

    class Fill(Module):
        def forward(self, input: torch.Tensor):
            return torch.fill(input, 1.5)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            with R.dataflow():
                # dtype="void" means "inherit the dtype of `input`".
                lv: R.Tensor((10, 10), dtype="float32") = R.full_like(
                    input, R.const(1.5, "float32"), dtype="void"
                )
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(10, 10, dtype=torch.float32),)
    verify_model(Fill(), example_args, {}, Expected)


def test_fill_inplace():
    """In-place fill_ lowers to R.full_like; the mutated input appears both as
    the returned value and as the aliased output, hence the two-element tuple."""

    class FillInplace(Module):
        def forward(self, input: torch.Tensor):
            input.fill_(42.0)
            return input

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((2, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.full_like(
                    input, R.const(42.0, "float32"), dtype="void"
                )
                gv: R.Tuple(
                    R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")
                ) = (lv, lv)
                R.output(gv)
            return gv

    example_args = (torch.randn(2, 3, dtype=torch.float32),)
    verify_model(FillInplace(), example_args, {}, Expected)


def test_masked_fill():
    """torch.masked_fill should lower to R.where(mask, fill_value, input)."""

    class Masked_Fill(Module):
        def forward(self, input: torch.Tensor, mask: torch.Tensor):
            return torch.masked_fill(input, mask, 0)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool")
        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
                lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input)
                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5)
    verify_model(Masked_Fill(), example_args, {}, Expected)


def test_masked_fill_inplace():
    """In-place masked_fill_ lowers to R.where; the result is duplicated in the
    output tuple because the mutated input aliases the return value."""

    class Masked_Fill_Inplace(Module):
        def forward(self, input: torch.Tensor, mask: torch.Tensor):
            return input.masked_fill_(mask, 1.5)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool")
        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.const(1.5, "float32")
                lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input)
                gv: R.Tuple(
                    R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")
                ) = (lv1, lv1)
                R.output(gv)
            return gv

    example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5)
    verify_model(Masked_Fill_Inplace(), example_args, {}, Expected)


def test_new_ones():
    """Tensor.new_ones should lower to R.full with a constant fill value of 1."""

    class NewOnes(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 2, 3), dtype="float32") = R.full(
                    (1, 2, 3), R.const(1, "float32"), dtype="float32"
                )
                gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
    verify_model(NewOnes(), example_args, {}, expected1)


def test_new_zeros():
    """Tensor.new_zeros should lower to R.full with a constant fill value of 0."""

    class NewZeros(torch.nn.Module):
        def forward(self, x):
            return x.new_zeros(1, 128, 128)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 128, 128), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 128, 128), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 128, 128), dtype="float32") = R.full(
                    R.shape([1, 128, 128]), R.const(0, "float32"), dtype="float32"
                )
                gv: R.Tuple(R.Tensor((1, 128, 128), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 128, 128, dtype=torch.float32),)
    verify_model(NewZeros(), example_args, {}, expected1)


def test_copy():
    """In-place copy_ from a scalar source lowers to astype (dtype conversion)
    followed by broadcast_to; the result is duplicated in the output tuple."""

    class CopyBroadcast(Module):
        def forward(self, x, src):
            x.copy_(src)
            return x

    @tvm.script.ir_module
    class expected_copy:
        @R.function
        def main(
            x: R.Tensor((2, 3), dtype="float32"), src: R.Tensor((), dtype="int64")
        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                # src is int64; copy_ converts it to the destination dtype first.
                lv: R.Tensor((), dtype="float32") = R.astype(src, dtype="float32")
                lv1: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(lv, (2, 3))
                gv: R.Tuple(
                    R.Tensor((2, 3), dtype="float32"), R.Tensor((2, 3), dtype="float32")
                ) = (
                    lv1,
                    lv1,
                )
                R.output(gv)
            return gv

    example_args = (torch.zeros(2, 3, dtype=torch.float32), torch.tensor(1, dtype=torch.int64))
    verify_model(CopyBroadcast(), example_args, {}, expected_copy)


def test_to_copy():
    """Dtype-conversion methods (.float(), .half(), .type(), .to()) should all
    lower to R.astype, including .to("cpu"), which keeps the dtype unchanged."""

    # float
    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    @tvm.script.ir_module
    class expected_float:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # half
    class ToHalf(Module):
        def forward(self, x):
            return x.half()

    @tvm.script.ir_module
    class expected_half:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16")
                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
                R.output(gv)
            return gv

    # type
    class Type(Module):
        def forward(self, x):
            return x.type(torch.float32)

    @tvm.script.ir_module
    class expected_type:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    class To1(Module):
        def forward(self, input):
            return input.to(torch.float16)

    @I.ir_module
    class expected_to1:
        @R.function
        def main(
            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
            with R.dataflow():
                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16")
                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
                R.output(gv)
            return gv

    # .to(device) keeps the dtype: still an astype to the same dtype.
    class To2(Module):
        def forward(self, input):
            return input.to("cpu")

    @I.ir_module
    class expected_to2:
        @R.function
        def main(
            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32")
                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
    verify_model(ToFloat(), example_args, {}, expected_float)
    verify_model(ToHalf(), example_args, {}, expected_half)
    verify_model(Type(), example_args, {}, expected_type)
    verify_model(To1(), example_args, {}, expected_to1)
    verify_model(To2(), example_args, {}, expected_to2)


def test_keep_params():
    """Importing with keep_params_as_input=True should lift the model weights
    into extra function parameters (after the real inputs), which
    detach_params then strips back out as bound constants."""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
            conv_bias: R.Tensor((6,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
            # "num_input": 1 marks only the first parameter as a runtime input;
            # the remaining parameters are the lifted weights.
            R.func_attr({"num_input": 1})
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
                    input_1,
                    conv_weight,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(conv_bias, [1, 6, 1, 1])
                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    from tvm.relax.frontend import detach_params

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    model = Conv2D1()
    # Use the module-level `export` (imported from torch.export at file top)
    # for consistency with the rest of this test suite.
    exported_program = export(model, args=example_args)
    mod = from_exported_program(exported_program, keep_params_as_input=True)
    mod, params = detach_params(mod)
    tvm.ir.assert_structural_equal(mod, expected1)
    func = mod["main"]
    params = params["main"]

    # Each detached parameter must line up (shape and dtype) with a function
    # parameter, skipping the single real input.
    assert len(params) == len(func.params) - 1
    for param_var, param_tensor in zip(func.params[1:], params):
        assert tuple(x.value for x in param_var.struct_info.shape.values) == param_tensor.shape
        assert param_var.struct_info.dtype == param_tensor.dtype

    # A single .detach() suffices; the original double .detach().detach() was
    # redundant (detach of a detached tensor is a no-op).
    tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().numpy())
    tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())


def test_unwrap_unit_return_tuple():
    """With unwrap_unit_return_tuple=True a single-element return tuple is
    unwrapped, so the Relax function returns the tensor directly."""

    class Identity(Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return (x,)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tensor((256, 256), dtype="float32"):
            with R.dataflow():
                gv: R.Tensor((256, 256), dtype="float32") = inp_0
                R.output(gv)
            return gv

    example_args = (torch.randn(256, 256, dtype=torch.float32),)
    exported_program = export(Identity(), args=example_args)
    mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True)
    tvm.ir.assert_structural_equal(mod, Expected)


def test_no_bind_return_tuple():
    """With no_bind_return_tuple=True the output tuple is not bound to a single
    variable; each element gets its own binding and the tuple is built at the
    return site."""

    class Identity(Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return (x, y)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32"),
            inp_1: R.Tensor((256, 256), dtype="float32"),
        ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32")):
            with R.dataflow():
                gv: R.Tensor((256, 256), dtype="float32") = inp_0
                gv1: R.Tensor((256, 256), dtype="float32") = inp_1
                R.output(gv, gv1)
            return (gv, gv1)

    example_args = (
        torch.randn(256, 256, dtype=torch.float32),
        torch.randn(256, 256, dtype=torch.float32),
    )
    exported_program = export(Identity(), args=example_args)
    mod = from_exported_program(exported_program, no_bind_return_tuple=True)
    tvm.ir.assert_structural_equal(mod, Expected)


def test_empty_like():
    """torch.empty_like lowers to R.zeros of the same shape/dtype (empty_like
    contents are unspecified, so zeros is a valid materialization)."""

    class EmptyLike(Module):
        def forward(self, data):
            return torch.empty_like(data)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            data: R.Tensor((5,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((5,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5,), dtype="float32") = R.zeros(R.shape([5]), dtype="float32")
                gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(5, dtype=torch.float32),)

    verify_model(EmptyLike(), example_args, {}, Expected)


def test_one_hot():
    """F.one_hot lowers to an arange over the classes, a broadcast equality
    against the expanded indices, and a cast back to int64."""

    class OneHot(Module):
        def forward(self, indices):
            return torch.nn.functional.one_hot(indices, num_classes=10)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            indices: R.Tensor((5,), dtype="int64"),
        ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((10,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((5, 1), dtype="int64") = R.expand_dims(indices, axis=[-1])
                lv2: R.Tensor((5, 10), dtype="bool") = R.equal(lv1, lv)
                lv3: R.Tensor((5, 10), dtype="int64") = R.astype(lv2, dtype="int64")
                gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv3,)
                R.output(gv)
            return gv

    example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),)

    verify_model(OneHot(), example_args, {}, Expected)


def test_ones_like():
    """torch.ones_like should lower to R.full_like with a constant 1."""

    class OnesLike(Module):
        def forward(self, input):
            return torch.ones_like(input)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((128, 128), dtype="float32")
        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
            with R.dataflow():
                # dtype="void" means "inherit the dtype of `input`".
                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
                    input, R.const(1, "int32"), dtype="void"
                )
                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.rand(128, 128, dtype=torch.float32),)

    verify_model(OnesLike(), example_args, {}, Expected)


def test_zero_inplace():
    """In-place zero_ lowers to R.full_like(0); the mutated input aliases the
    return value, hence the duplicated two-element output tuple."""

    class ZeroInplace(Module):
        def forward(self, input):
            return input.zero_()

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((128, 128), dtype="float32")
        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
                    input, R.const(0, "int32"), dtype="void"
                )
                gv: R.Tuple(
                    R.Tensor((128, 128), dtype="float32"), R.Tensor((128, 128), dtype="float32")
                ) = (
                    lv,
                    lv,
                )
                R.output(gv)
            return gv

    example_args = (torch.rand(128, 128, dtype=torch.float32),)

    verify_model(ZeroInplace(), example_args, {}, Expected)


def test_zeros():
    """torch.zeros should lower to R.full with a constant 0.0 fill value."""

    class Zeros(Module):
        def forward(self, input):
            return torch.zeros(5, 2)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((128, 128), dtype="float32")
        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5, 2), dtype="float32") = R.full(
                    R.shape([5, 2]), R.const(0.0, "float32"), dtype="float32"
                )
                gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.rand(128, 128, dtype=torch.float32),)

    verify_model(Zeros(), example_args, {}, Expected)


def test_zeros_like():
    """torch.zeros_like should lower to R.full_like with a constant 0."""

    class ZerosLike(Module):
        def forward(self, input):
            return torch.zeros_like(input)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((128, 128), dtype="float32")
        ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")):
            with R.dataflow():
                # dtype="void" means "inherit the dtype of `input`".
                lv: R.Tensor((128, 128), dtype="float32") = R.full_like(
                    input, R.const(0, "int32"), dtype="void"
                )
                gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.rand(128, 128, dtype=torch.float32),)
    verify_model(ZerosLike(), example_args, {}, Expected)


def test_type_as():
    """Tensor.type_as should lower to R.astype targeting the other tensor's dtype."""

    class TypeAs(Module):
        def forward(self, input, other):
            return input.type_as(other)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((128, 128), dtype="float32"),
            other: R.Tensor((128, 128), dtype="float16"),
        ) -> R.Tuple(R.Tensor((128, 128), dtype="float16")):
            with R.dataflow():
                lv: R.Tensor((128, 128), dtype="float16") = R.astype(input, dtype="float16")
                gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,)
                R.output(gv)
            return gv

    example_args = (
        torch.rand(128, 128, dtype=torch.float32),
        torch.rand(128, 128, dtype=torch.float16),
    )

    verify_model(TypeAs(), example_args, {}, Expected)


def test_select():
    """torch.select should lower to R.take with a constant index along the axis."""

    class Select(Module):
        def forward(self, input):
            return torch.select(input, 0, 1)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0)
                gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(2, 3, dtype=torch.float32),)

    verify_model(Select(), example_args, {}, Expected)


def test_unflatten():
    """aten.unflatten lowers to R.reshape; positive and negative dims that
    refer to the same axis produce identical modules."""

    class Unflatten(Module):
        def forward(self, input):
            return torch.ops.aten.unflatten(input, 1, (3, 5))

    # dim=-2 on a rank-3 tensor is the same axis as dim=1 above.
    class Unflatten1(Module):
        def forward(self, input):
            return torch.ops.aten.unflatten(input, -2, (3, 5))

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            inp_0: R.Tensor((2, 15, 7), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 3, 5, 7), dtype="float32") = R.reshape(inp_0, [2, 3, 5, 7])
                gv: R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(2, 15, 7, dtype=torch.float32),)

    verify_model(Unflatten(), example_args, {}, Expected)
    verify_model(Unflatten1(), example_args, {}, Expected)


def test_gather():
    """torch.gather should lower to R.gather_elements, preserving the dim
    argument (positive or negative) as the axis."""

    class Gather0(Module):
        def forward(self, data, indices):
            return torch.gather(data, 0, indices)

    class Gather1(Module):
        def forward(self, data, indices):
            return torch.gather(data, 1, indices)

    class Gather2(Module):
        def forward(self, data, indices):
            return torch.gather(data, -1, indices)

    class Gather3(Module):
        def forward(self, data, indices):
            return torch.gather(data, -2, indices)

    @tvm.script.ir_module
    class Expected0:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int64"),
        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=0)
                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int64"),
        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=1)
                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int64"),
        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-1)
                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected3:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int64"),
        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-2)
                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(2, 3, dtype=torch.float32),
        torch.randint(0, 3, (2, 3), dtype=torch.int64),
    )

    verify_model(Gather0(), example_args, {}, Expected0)
    verify_model(Gather1(), example_args, {}, Expected1)
    verify_model(Gather2(), example_args, {}, Expected2)
    verify_model(Gather3(), example_args, {}, Expected3)


def test_index_put():
    """Check that in-place ``Tensor.index_put_`` converts to ``R.index_put``.

    Covers 1D through 5D data tensors with one index tensor per dimension,
    plus three "broadcast" variants where some indices are built inside the
    model with ``torch.arange(...).unsqueeze(1)`` (multi-dimensional index
    tensors).  Each Expected module returns a two-element tuple holding the
    same ``R.index_put`` result twice, matching what the exported program
    produces for the in-place op.
    """
    # Test case 1: 1D input
    class IndexPut1D(Module):
        def forward(self, data, indices_0, values):
            indices_tuple = (indices_0,)
            return data.index_put_(indices_tuple, values, accumulate=False)

    example_args_1d = (
        torch.randn(64, dtype=torch.float32),
        torch.randint(0, 64, (128,), dtype=torch.int64),
        torch.randn(128, dtype=torch.float32),
    )

    @I.ir_module
    class Expected1D:
        @R.function
        def main(
            data: R.Tensor((64,), dtype="float32"),
            indices_0: R.Tensor((128,), dtype="int64"),
            values: R.Tensor((128,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64,), dtype="float32") = R.index_put(
                    data, R.tuple(indices_0), values, accumulate=False
                )
                gv: R.Tuple(R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = (
                    lv,
                    lv,
                )
                R.output(gv)
            return gv

    # Test case 2: 2D input
    class IndexPut2D(Module):
        def forward(self, data, indices_0, indices_1, values):
            indices_tuple = (indices_0, indices_1)
            return data.index_put_(indices_tuple, values, accumulate=False)

    example_args_2d = (
        torch.randn(32, 64, dtype=torch.float32),
        torch.randint(0, 32, (128,), dtype=torch.int64),
        torch.randint(0, 64, (128,), dtype=torch.int64),
        torch.randn(128, dtype=torch.float32),
    )

    @I.ir_module
    class Expected2D:
        @R.function
        def main(
            data: R.Tensor((32, 64), dtype="float32"),
            indices_0: R.Tensor((128,), dtype="int64"),
            indices_1: R.Tensor((128,), dtype="int64"),
            values: R.Tensor((128,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
                    data, R.tuple(indices_0, indices_1), values, accumulate=False
                )
                gv: R.Tuple(
                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")
                ) = (lv, lv)
                R.output(gv)
            return gv

    # Test case 3: 3D input
    class IndexPut3D(Module):
        def forward(self, data, indices_0, indices_1, indices_2, values):
            indices_tuple = (indices_0, indices_1, indices_2)
            return data.index_put_(indices_tuple, values, accumulate=False)

    example_args_3d = (
        torch.randn(16, 32, 64, dtype=torch.float32),
        torch.randint(0, 16, (128,), dtype=torch.int64),
        torch.randint(0, 32, (128,), dtype=torch.int64),
        torch.randint(0, 64, (128,), dtype=torch.int64),
        torch.randn(128, dtype=torch.float32),
    )

    @I.ir_module
    class Expected3D:
        @R.function
        def main(
            data: R.Tensor((16, 32, 64), dtype="float32"),
            indices_0: R.Tensor((128,), dtype="int64"),
            indices_1: R.Tensor((128,), dtype="int64"),
            indices_2: R.Tensor((128,), dtype="int64"),
            values: R.Tensor((128,), dtype="float32"),
        ) -> R.Tuple(
            R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
        ):
            with R.dataflow():
                lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
                    data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False
                )
                gv: R.Tuple(
                    R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
                ) = (lv, lv)
                R.output(gv)
            return gv

    # Test case 4: 4D input
    class IndexPut4D(Module):
        def forward(self, data, indices_0, indices_1, indices_2, indices_3, values):
            indices_tuple = (indices_0, indices_1, indices_2, indices_3)
            return data.index_put_(indices_tuple, values, accumulate=False)

    example_args_4d = (
        torch.randn(8, 16, 32, 64, dtype=torch.float32),
        torch.randint(0, 8, (128,), dtype=torch.int64),
        torch.randint(0, 16, (128,), dtype=torch.int64),
        torch.randint(0, 32, (128,), dtype=torch.int64),
        torch.randint(0, 64, (128,), dtype=torch.int64),
        torch.randn(128, dtype=torch.float32),
    )

    @I.ir_module
    class Expected4D:
        @R.function
        def main(
            data: R.Tensor((8, 16, 32, 64), dtype="float32"),
            indices_0: R.Tensor((128,), dtype="int64"),
            indices_1: R.Tensor((128,), dtype="int64"),
            indices_2: R.Tensor((128,), dtype="int64"),
            indices_3: R.Tensor((128,), dtype="int64"),
            values: R.Tensor((128,), dtype="float32"),
        ) -> R.Tuple(
            R.Tensor((8, 16, 32, 64), dtype="float32"),
            R.Tensor((8, 16, 32, 64), dtype="float32"),
        ):
            with R.dataflow():
                lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put(
                    data,
                    R.tuple(indices_0, indices_1, indices_2, indices_3),
                    values,
                    accumulate=False,
                )
                gv: R.Tuple(
                    R.Tensor((8, 16, 32, 64), dtype="float32"),
                    R.Tensor((8, 16, 32, 64), dtype="float32"),
                ) = (lv, lv)
                R.output(gv)
            return gv

    # Test case 5: 5D input
    class IndexPut5D(Module):
        def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values):
            indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4)
            return data.index_put_(indices_tuple, values, accumulate=False)

    example_args_5d = (
        torch.randn(4, 8, 16, 32, 64, dtype=torch.float32),
        torch.randint(0, 4, (128,), dtype=torch.int64),
        torch.randint(0, 8, (128,), dtype=torch.int64),
        torch.randint(0, 16, (128,), dtype=torch.int64),
        torch.randint(0, 32, (128,), dtype=torch.int64),
        torch.randint(0, 64, (128,), dtype=torch.int64),
        torch.randn(128, dtype=torch.float32),
    )

    @I.ir_module
    class Expected5D:
        @R.function
        def main(
            data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
            indices_0: R.Tensor((128,), dtype="int64"),
            indices_1: R.Tensor((128,), dtype="int64"),
            indices_2: R.Tensor((128,), dtype="int64"),
            indices_3: R.Tensor((128,), dtype="int64"),
            indices_4: R.Tensor((128,), dtype="int64"),
            values: R.Tensor((128,), dtype="float32"),
        ) -> R.Tuple(
            R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
            R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
        ):
            with R.dataflow():
                lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put(
                    data,
                    R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4),
                    values,
                    accumulate=False,
                )
                gv: R.Tuple(
                    R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
                    R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
                ) = (lv, lv)
                R.output(gv)
            return gv

    # Test case 6: 2D input with multi-dimensional index (broadcasting)
    # This tests the multi-dimensional index support with broadcasting
    class IndexPutBroadcast1D(Module):
        def forward(self, data, indices_1):
            # (32, 1) row index broadcast against the (10,) column index
            indices_0 = torch.arange(data.shape[0]).unsqueeze(1)
            values = torch.ones(data.shape[0], len(indices_1), dtype=data.dtype)
            return data.index_put_((indices_0, indices_1), values, accumulate=False)

    example_args_broadcast1 = (
        torch.randn(32, 64, dtype=torch.float32),
        torch.randint(0, 64, (10,), dtype=torch.int64),
    )

    @I.ir_module
    class ExpectedBroadcast1D:
        @R.function
        def main(
            data: R.Tensor((32, 64), dtype="float32"),
            indices_1: R.Tensor((10,), dtype="int64"),
        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((32,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(32), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((32, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
                lv2: R.Tensor((32, 10), dtype="float32") = R.full(
                    R.shape([32, 10]), R.const(1.0, "float32"), dtype="float32"
                )
                lv3: R.Tensor((32, 64), dtype="float32") = R.index_put(
                    data, R.tuple(lv1, indices_1), lv2, accumulate=False
                )
                gv: R.Tuple(
                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")
                ) = (lv3, lv3)
                R.output(gv)
            return gv

    # Test case 7: 2D input with multi-dimensional index (second position)
    class IndexPutBroadcast2D(Module):
        def forward(self, data, indices_0):
            indices_1 = torch.arange(data.shape[1]).unsqueeze(1)
            values = torch.ones(len(indices_0), data.shape[1], dtype=data.dtype)
            return data.index_put_((indices_0, indices_1), values, accumulate=False)

    example_args_broadcast2 = (
        torch.randn(32, 64, dtype=torch.float32),
        torch.randint(0, 32, (10,), dtype=torch.int64),
    )

    @I.ir_module
    class ExpectedBroadcast2D:
        @R.function
        def main(
            data: R.Tensor((32, 64), dtype="float32"),
            indices_0: R.Tensor((10,), dtype="int64"),
        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((64,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
                lv2: R.Tensor((10, 64), dtype="float32") = R.full(
                    R.shape([10, 64]), R.const(1.0, "float32"), dtype="float32"
                )
                lv3: R.Tensor((32, 64), dtype="float32") = R.index_put(
                    data, R.tuple(indices_0, lv1), lv2, accumulate=False
                )
                gv: R.Tuple(
                    R.Tensor((32, 64), dtype="float32"), R.Tensor((32, 64), dtype="float32")
                ) = (lv3, lv3)
                R.output(gv)
            return gv

    # Test case 8: 3D input with mixed 1D and 2D indices
    class IndexPutBroadcast3D(Module):
        def forward(self, data, indices_1):
            indices_0 = torch.arange(data.shape[0]).unsqueeze(1)
            indices_2 = torch.arange(data.shape[2]).unsqueeze(1)
            values = torch.ones(data.shape[0], len(indices_1), data.shape[2], dtype=data.dtype)
            return data.index_put_((indices_0, indices_1, indices_2), values, accumulate=False)

    example_args_broadcast3d = (
        torch.randn(16, 32, 64, dtype=torch.float32),
        torch.randint(0, 32, (10,), dtype=torch.int64),
    )

    @I.ir_module
    class ExpectedBroadcast3D:
        @R.function
        def main(
            data: R.Tensor((16, 32, 64), dtype="float32"),
            indices_1: R.Tensor((10,), dtype="int64"),
        ) -> R.Tuple(
            R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
        ):
            with R.dataflow():
                lv: R.Tensor((16,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(16), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((16, 1), dtype="int64") = R.expand_dims(lv, axis=[1])
                lv2: R.Tensor((64,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(64), R.prim_value(1), dtype="int64"
                )
                lv3: R.Tensor((64, 1), dtype="int64") = R.expand_dims(lv2, axis=[1])
                lv4: R.Tensor((16, 10, 64), dtype="float32") = R.full(
                    R.shape([16, 10, 64]), R.const(1.0, "float32"), dtype="float32"
                )
                lv5: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
                    data, R.tuple(lv1, indices_1, lv3), lv4, accumulate=False
                )
                gv: R.Tuple(
                    R.Tensor((16, 32, 64), dtype="float32"), R.Tensor((16, 32, 64), dtype="float32")
                ) = (lv5, lv5)
                R.output(gv)
            return gv

    # Run verification for each case
    verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
    verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
    verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D)
    verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D)
    verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D)


def test_flip():
    """Check that ``torch.flip`` converts to ``R.flip`` for axis 0 and axis 1."""

    class Flip0(Module):
        def forward(self, data):
            return torch.flip(data, [0])

    class Flip1(Module):
        def forward(self, data):
            return torch.flip(data, [1])

    @tvm.script.ir_module
    class Expected0:
        @R.function
        def main(
            inp_0: R.Tensor((2, 2), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0)
                gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((2, 2), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1)
                gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(2, 2, dtype=torch.float32),)

    verify_model(Flip0(), example_args, {}, Expected0)
    verify_model(Flip1(), example_args, {}, Expected1)


def test_take():
    """Check that ``torch.take`` converts to a flattening ``R.reshape``
    followed by ``R.take`` along axis 0 with ``mode="fast"``."""

    class Take(Module):
        def forward(self, data, indices):
            return torch.take(data, indices)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            data: R.Tensor((5,), dtype="float32"),
            indices: R.Tensor((3,), dtype="int64"),
        ) -> R.Tuple(R.Tensor((3,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5,), dtype="float32") = R.reshape(data, R.shape([5]))
                lv1: R.Tensor((3,), dtype="float32") = R.take(lv, indices, axis=0, mode="fast")
                gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(5, dtype=torch.float32),
        torch.randint(0, 5, (3,), dtype=torch.int64),
    )

    verify_model(Take(), example_args, {}, Expected)


def test_std():
    """Check that ``torch.std`` converts to ``R.variance`` followed by ``R.sqrt``."""

    class Std(Module):
        def forward(self, x):
            return torch.std(x)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
                lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(5, 3, dtype=torch.float32),)
    verify_model(Std(), example_args, {}, Expected)


def test_var():
    """Check that ``torch.var`` converts to a full-reduction ``R.variance``."""

    class Var(Module):
        def forward(self, x):
            return torch.var(x)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(5, 3, dtype=torch.float32),)
    verify_model(Var(), example_args, {}, Expected)


def test_prod():
    """Check that ``torch.prod`` converts to a full-reduction ``R.prod``."""

    class Prod(Module):
        def forward(self, x):
            return torch.prod(x)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.prod(x, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(5, 3, dtype=torch.float32),)
    verify_model(Prod(), example_args, {}, Expected)


def test_cumprod():
    """Check that ``torch.cumprod`` along dim 0 converts to ``R.cumprod``."""

    class Cumprod(Module):
        def forward(self, x):
            return torch.cumprod(x, 0)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            inp_0: R.Tensor((5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
                gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_input = torch.randn(5, 3, dtype=torch.float32)
    verify_model(Cumprod(), (example_input,), {}, Expected)


def test_where():
    """Check that the three-argument ``torch.where`` converts to ``R.where``."""

    class Where(Module):
        def forward(self, condition, x, y):
            return torch.where(condition, x, y)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            inp_0: R.Tensor((5, 3), dtype="bool"),
            inp_1: R.Tensor((5, 3), dtype="float32"),
            inp_2: R.Tensor((5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
                gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    condition = torch.randint(0, 2, (5, 3), dtype=torch.bool)
    x = torch.randn(5, 3, dtype=torch.float32)
    y = torch.randn(5, 3, dtype=torch.float32)

    verify_model(Where(), (condition, x, y), {}, Expected)


def test_bucketize():
    """Check that ``torch.bucketize`` converts to ``R.bucketize`` with the
    default ``out_int32=False`` / ``right=False`` flags."""

    class Bucketize(Module):
        def forward(self, input_tensor, boundaries):
            return torch.bucketize(input_tensor, boundaries)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((20,), dtype="int64"), boundaries: R.Tensor((10,), dtype="int64")
        ) -> R.Tuple(R.Tensor((20,), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((20,), dtype="int64") = R.bucketize(
                    input, boundaries, out_int32=False, right=False
                )
                gv: R.Tuple(R.Tensor((20,), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    input_tensor = torch.arange(0, 20)
    boundaries = torch.arange(0, 20, 2)

    verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)


def test_argsort():
    """Check ``torch.argsort(dim=1, descending=True)`` conversion.

    The expected IR runs ``R.argsort`` (int32 indices), gathers the sorted
    values with ``R.gather_elements``, packs (values, indices) into a tuple,
    and returns only the index element — mirroring the decomposed aten sort.
    """

    class Argsort(Module):
        def forward(self, x):
            return torch.argsort(x, dim=1, descending=True)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype="int32")):
            with R.dataflow():
                lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
                    x, axis=1, descending=True, dtype="int32"
                )
                lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1)
                lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
                    lv1,
                    lv,
                )
                lv3: R.Tensor((5, 3), dtype="int32") = lv2[1]
                gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,)
                R.output(gv)
            return gv

    example_args = (torch.randn(5, 3, dtype=torch.float32),)
    verify_model(Argsort(), example_args, {}, Expected)


def test_topk():
    """Check that ``torch.topk`` converts to ``R.topk`` with
    ``ret_type="both"``, returning the (values, indices) pair."""

    class Topk(Module):
        def forward(self, x):
            return torch.topk(x, k=2, dim=1, largest=True, sorted=True)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((5, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")):
            with R.dataflow():
                lv: R.Tuple(
                    R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")
                ) = R.topk(x, k=2, axis=1, ret_type="both", largest=True, dtype="int64")
                lv1: R.Tensor((5, 2), dtype="float32") = lv[0]
                lv2: R.Tensor((5, 2), dtype="int64") = lv[1]
                gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = (
                    lv1,
                    lv2,
                )
                R.output(gv)
            return gv

    example_args = (torch.randn(5, 3, dtype=torch.float32),)
    verify_model(Topk(), example_args, {}, Expected)


def test_dynamic_shape():
    """Check that a ``torch.export.Dim`` dynamic batch dimension is mapped
    to a symbolic ``tir.SizeVar`` in the converted Relax function."""

    class DynamicModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.ops.aten.add.Tensor(x1, x2)

    # Symbolic batch dimension shared by both inputs.
    B = tvm.tir.SizeVar("BatchSize", dtype="int64")

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            lhs: R.Tensor((B, 4), dtype="float32"),
            rhs: R.Tensor((B, 4), dtype="float32"),
        ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
                gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(2, 4), torch.randn(2, 4))
    batch = torch.export.Dim("batch")
    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

    verify_model(
        DynamicModel(),
        example_args,
        {},
        Expected,
        dynamic_shapes=dynamic_shapes,
        run_ep_decomposition=True,
    )


def test_broadcast_to():
    """Check that ``torch.broadcast_to`` converts to ``R.broadcast_to``."""

    class BroadcastTo(Module):
        def forward(self, x):
            return torch.broadcast_to(x, (5, 3))

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((5, 1), dtype="float32")
        ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3]))
                gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,)
                R.output(gv)

            return gv

    example_args = (torch.randn(5, 1, dtype=torch.float32),)
    verify_model(BroadcastTo(), example_args, {}, Expected)


def test_narrow():
    """Check that ``torch.narrow(x, 1, 0, 2)`` converts to an
    ``R.strided_slice`` over axis 1 with begin 0, end 2, stride 1."""

    class Narrow(Module):
        def forward(self, x):
            return torch.narrow(x, 1, 0, 2)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((5, 3), dtype="float32")
        ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
                    x,
                    (R.prim_value(1),),
                    (R.prim_value(0),),
                    (R.prim_value(2),),
                    (R.prim_value(1),),
                    assume_inbound=False,
                )
                gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
                R.output(gv)

            return gv

    example_args = (torch.randn(5, 3, dtype=torch.float32),)
    verify_model(Narrow(), example_args, {}, Expected)


def test_item():
    """Check that ``Tensor.item`` on a one-element tensor converts to an
    ``R.take`` of element 0, yielding a scalar (rank-0) tensor."""

    class Item(Module):
        def forward(self, x):
            return x.item()

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.take(input, R.const(0, "int64"), axis=0)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, dtype=torch.float32),)
    verify_model(Item(), example_args, {}, Expected)


def test_norm():
    """Check ``torch.norm`` conversion for several values of ``p``.

    Cases exercised (see the ``norms`` table at the bottom):
    +inf -> max(|x|), -inf -> min(|x|), and finite p -> the
    power/sum/power expansion, with and without ``keepdim``.
    """

    class Norm(Module):
        def __init__(self, p, dim=None, keepdim=False):
            super().__init__()
            self.p = p
            self.dim = dim
            self.keepdim = keepdim

        def forward(self, x):
            return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)

    # p = +inf: reduces to max of absolute values.
    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # p = -inf: reduces to min of absolute values.
    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    # p = 2: sum(|x|^2) ** 0.5.
    @tvm.script.ir_module
    class Expected3:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32"))
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32"))
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # p = 1: sum(|x|^1) ** 1.
    @tvm.script.ir_module
    class Expected4:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32"))
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32"))
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # p = -4 with keepdim=True: sum keeps all-ones dims.
    @tvm.script.ir_module
    class Expected5:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32"))
                lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
                lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(
                    lv2, R.const(-0.25, "float32")
                )
                gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # p = 0.5 with keepdim=True.
    @tvm.script.ir_module
    class Expected6:
        @R.function
        def main(
            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32"))
                lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
                lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(2.0, "float32"))
                gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    # (p, dim, keepdim) -> expected module.
    norms = [
        ((float("inf"), None, False), Expected1),
        ((float("-inf"), None, False), Expected2),
        ((float(2), None, False), Expected3),
        ((float(1.0), None, False), Expected4),
        ((float(-4), None, True), Expected5),
        ((float(0.5), None, True), Expected6),
    ]

    example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)

    for (p, dim, keepdim), expected in norms:
        verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected)


def test_eye():
    """Check ``torch.eye`` conversion (rectangular 3x5 and square 5x5).

    The expected IR builds the identity matrix from two ``R.arange`` index
    vectors, an ``R.equal`` comparison of row against column index, and an
    ``R.where`` selecting 1.0 on the diagonal, 0.0 elsewhere.  The model's
    ``input`` argument is unused by ``forward``; it only gives the export a
    concrete example input.
    """

    class Eye1(Module):
        def forward(self, input):
            return torch.eye(3, 5, dtype=torch.float32)

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            input: R.Tensor((3, 5), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((5,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
                )
                lv2: R.Tensor((3, 1), dtype="int64") = R.expand_dims(lv, axis=[-1])
                lv3: R.Tensor((3, 5), dtype="bool") = R.equal(lv2, lv1)
                lv4: R.Tensor((1,), dtype="float32") = R.full(
                    R.shape([1]), R.const(1.0, "float32"), dtype="float32"
                )
                lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
                lv6: R.Tensor((3, 5), dtype="float32") = R.where(lv3, lv4, lv5)
                gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv6,)
                R.output(gv)
            return gv

    class Eye2(Module):
        def forward(self, input):
            return torch.eye(5, dtype=torch.float32)

    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            input: R.Tensor((5,), dtype="float32")
        ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((5,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((5,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
                )
                lv2: R.Tensor((5, 1), dtype="int64") = R.expand_dims(lv, axis=[-1])
                lv3: R.Tensor((5, 5), dtype="bool") = R.equal(lv2, lv1)
                lv4: R.Tensor((1,), dtype="float32") = R.full(
                    R.shape([1]), R.const(1.0, "float32"), dtype="float32"
                )
                lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
                lv6: R.Tensor((5, 5), dtype="float32") = R.where(lv3, lv4, lv5)
                gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv6,)
                R.output(gv)
            return gv

    example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
    verify_model(Eye1(), example_args1, {}, Expected1)

    example_args2 = (torch.randn(5, dtype=torch.float32),)
    verify_model(Eye2(), example_args2, {}, Expected2)


def test_cross_entropy():
    """nn.CrossEntropyLoss with a target stored as a module attribute.

    The target tensor is captured as a graph constant during export, so the
    expected Relax module inlines it as ``R.const([0, 1, 2, 1])``.
    """

    class CrossEntropyModule(Module):
        def __init__(self):
            super().__init__()
            self.criterion = nn.CrossEntropyLoss()
            # Attribute tensor: baked into the exported graph as a constant.
            self.target = torch.tensor([0, 1, 2, 1])

        def forward(self, x):
            return self.criterion(x, self.target)

    # Expected lowering: log_softmax, mask out ignored targets (the repeated
    # not_equal against R.const(-100) — presumably CrossEntropyLoss's default
    # ignore_index; verify against the frontend converter), gather the target
    # class log-probabilities, negate, then a masked mean reduction.
    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32")
                lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1)
                lv2: R.Tensor((4,), dtype="bool") = R.not_equal(
                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
                )
                lv3: R.Tensor((), dtype="int64") = R.const(0, "int64")
                lv4: R.Tensor((4,), dtype="int64") = R.where(
                    lv2, R.const([0, 1, 2, 1], dtype="int64"), lv3
                )
                lv5: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv4, axis=[1])
                lv6: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv1, lv5, axis=1)
                lv7: R.Tensor((4,), dtype="float32") = R.squeeze(lv6, axis=[1])
                lv8: R.Tensor((4,), dtype="float32") = R.negative(lv7)
                lv9: R.Tensor((4,), dtype="bool") = R.not_equal(
                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
                )
                lv10: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
                lv11: R.Tensor((4,), dtype="float32") = R.where(lv9, lv8, lv10)
                lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
                    R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
                )
                lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False)
                lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32")
                lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False)
                lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
                gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
                R.output(gv)
            return gv

    example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
    verify_model(CrossEntropyModule(), example_args1, {}, Expected1)


def test_linspace():
    """torch.linspace with constant start/end/steps.

    The forward ignores its tensor argument; linspace(0, 1, steps=9) is
    lowered to an arange over the indices, built forward from the start for
    the first half and backward from the end for the second half, merged
    with a where on ``index < 4``.
    """

    class Linspace(Module):
        def forward(self, input):
            # `input` is only used to trace the graph; the output is constant.
            return torch.linspace(0, 1, steps=9, dtype=torch.float32)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            input: R.Tensor((9, 9), dtype="float32")
        ) -> R.Tuple(R.Tensor((9,), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((9,), dtype="int64") = R.arange(
                    R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64"
                )
                lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64"))
                lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32")
                lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32"))
                lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32"))
                lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv)
                lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32")
                lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32"))
                lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7)
                lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8)
                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,)
                R.output(gv)
            return gv

    example_args = (torch.randn(9, 9, dtype=torch.float32),)
    verify_model(Linspace(), example_args, {}, Expected)


@pytest.mark.parametrize(
    "torch_dtype, relax_dtype",
    [
        (torch.float32, "float32"),
        (torch.float16, "float16"),
        (torch.bfloat16, "bfloat16"),
        (torch.int64, "int64"),
        (torch.int32, "int32"),
        (torch.bool, "bool"),
    ],
)
def test_dtypes(torch_dtype, relax_dtype):
    """A single aten.add must round-trip each supported torch dtype to the
    matching Relax dtype string (parametrized above)."""
    example_args = (
        torch.randint(0, 10, (10, 10)).to(torch_dtype),
        torch.randint(0, 10, (10, 10)).to(torch_dtype),
    )

    class Model(Module):
        def forward(self, lhs: torch.Tensor, rhs: torch.Tensor):
            return torch.ops.aten.add(lhs, rhs)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            lhs: R.Tensor((10, 10), dtype=relax_dtype),
            rhs: R.Tensor((10, 10), dtype=relax_dtype),
        ) -> R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)):
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype=relax_dtype) = relax.op.add(lhs, rhs)
                gv: R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)) = (lv,)
                R.output(gv)
            return gv

    verify_model(Model(), example_args, {}, Expected)


def test_mm():
    """torch.mm lowers to R.matmul with an explicit float32 out_dtype."""

    class MatrixMultiply(Module):
        def forward(self, a, b):
            return torch.mm(a, b)

    example_args = (
        torch.randn(2, 3, dtype=torch.float32),
        torch.randn(3, 4, dtype=torch.float32),
    )

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            a: R.Tensor((2, 3), dtype="float32"),
            b: R.Tensor((3, 4), dtype="float32"),
        ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32")
                gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    verify_model(MatrixMultiply(), example_args, {}, Expected)


def test_lstm():
    """End-to-end check that nn.LSTM lowers through the exported-program
    frontend and matches eager PyTorch numerically on the LLVM/CPU target.

    Covers both batch_first=True and batch_first=False input layouts.
    """
    target = tvm.target.Target("llvm")

    def _run_and_compare(model, x):
        # Export `model`, compile with Relax, execute on the VM, and compare
        # the result against eager PyTorch. Shared by both layout cases.
        with torch.no_grad():
            pytorch_output = model(x)
        exported_program = export(model, args=(x,))
        mod = from_exported_program(exported_program)
        ex = relax.build(mod, target)
        vm = relax.VirtualMachine(ex, tvm.cpu())
        tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
        # The VM may return a bare tensor or a tuple whose first element
        # holds the output.
        if hasattr(tvm_output, "numpy"):
            tvm_output_np = tvm_output.numpy()
        else:
            tvm_output_np = tvm_output[0].numpy()
        assert (
            pytorch_output.shape == tvm_output_np.shape
        ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
        np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)

    class BasicLSTM(nn.Module):
        # batch_first=True: input is (batch, seq, feature).
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=4,
                hidden_size=8,
                num_layers=1,
                batch_first=True,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.lstm(x)
            return y

    # Seed before sampling the input and instantiating the model so both
    # are deterministic across runs.
    torch.manual_seed(42)
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    _run_and_compare(BasicLSTM(), x)

    class SeqFirstLSTM(nn.Module):
        # batch_first=False: input is (seq, batch, feature).
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=3,
                hidden_size=6,
                num_layers=1,
                batch_first=False,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.lstm(x)
            return y

    torch.manual_seed(43)
    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
    _run_and_compare(SeqFirstLSTM(), x2)


def test_tensor_none_tuple():
    """A forward returning (tensor, None) maps None to R.null_value() with
    an R.Object slot in the output tuple."""
    example_args = (torch.tensor([1.0, 2.0, 3.0]),)

    class TensorNoneModel(Module):
        def forward(self, x):
            return x + 1, None

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((3,), dtype="float32")
        ) -> R.Tuple(R.Tensor((3,), dtype="float32"), R.Object):
            with R.dataflow():
                lv: R.Tensor((3,), dtype="float32") = R.add(x, R.const(1.0, "float32"))
                gv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Object) = (lv, R.null_value())
                R.output(gv)
            return gv

    verify_model(TensorNoneModel(), example_args, {}, Expected)


def test_gru():
    """End-to-end check that nn.GRU lowers through the exported-program
    frontend and matches eager PyTorch numerically on the LLVM/CPU target.

    Covers both batch_first=True and batch_first=False input layouts.
    """
    target = tvm.target.Target("llvm")

    def _run_and_compare(model, x):
        # Export `model`, compile with Relax, execute on the VM, and compare
        # the result against eager PyTorch. Shared by both layout cases.
        with torch.no_grad():
            pytorch_output = model(x)
        exported_program = export(model, args=(x,))
        mod = from_exported_program(exported_program)
        ex = relax.build(mod, target)
        vm = relax.VirtualMachine(ex, tvm.cpu())
        tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
        # The VM may return a bare tensor or a tuple whose first element
        # holds the output.
        if hasattr(tvm_output, "numpy"):
            tvm_output_np = tvm_output.numpy()
        else:
            tvm_output_np = tvm_output[0].numpy()
        assert (
            pytorch_output.shape == tvm_output_np.shape
        ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
        np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)

    class BasicGRU(nn.Module):
        # batch_first=True: input is (batch, seq, feature).
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=4,
                hidden_size=8,
                num_layers=1,
                batch_first=True,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    # Seed before sampling the input and instantiating the model so both
    # are deterministic across runs.
    torch.manual_seed(42)
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    _run_and_compare(BasicGRU(), x)

    class SeqFirstGRU(nn.Module):
        # batch_first=False: input is (seq, batch, feature).
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=3,
                hidden_size=6,
                num_layers=1,
                batch_first=False,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    torch.manual_seed(43)
    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
    _run_and_compare(SeqFirstGRU(), x2)


def test_dynamic_shape_with_range_constraints():
    """A torch.export.Dim(min=1, max=64) on the batch axis must surface as a
    symbolic size var plus tir_var_lower_bound/tir_var_upper_bound attrs."""

    class DynamicModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.ops.aten.add.Tensor(x1, x2)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32")
        ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
            s0 = T.int64(is_size_var=True)
            R.func_attr({"tir_var_lower_bound": {"s0": 1}, "tir_var_upper_bound": {"s0": 64}})
            with R.dataflow():
                lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
                gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(8, 4), torch.randn(8, 4))
    batch = torch.export.Dim("batch", min=1, max=64)
    # Both inputs share the same Dim, so they map to the same size var s0.
    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)

    mod = from_exported_program(exported_program)
    tvm.ir.assert_structural_equal(mod, Expected)


def test_dynamic_shape_with_addition_constraints():
    """A derived dim (batch + 1) becomes its own size var (s0___1) whose
    bounds are the base Dim's bounds shifted up by one."""

    class ConcatModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.cat([x, y], dim=0)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32")
        ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")):
            s0 = T.int64(is_size_var=True)
            s0___1 = T.int64(is_size_var=True)
            R.func_attr(
                {
                    "tir_var_lower_bound": {"s0": 1, "s0___1": 2},
                    "tir_var_upper_bound": {"s0": 64, "s0___1": 65},
                }
            )
            with R.dataflow():
                lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0)
                gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    batch = torch.export.Dim("batch", min=1, max=64)
    example_args = (torch.randn(8, 4), torch.randn(9, 4))
    dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}
    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)

    mod = from_exported_program(exported_program)
    # map_free_vars=True: the size-var names differ between the two modules.
    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


def test_dynamic_shape_with_subtraction_constraints():
    """A derived dim (batch - 1) becomes its own size var (s1) whose bounds
    are the base Dim's bounds shifted down by one (hence lower bound 0)."""

    class ConcatModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.cat([x, y], dim=0)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor(("s1___1", 4), dtype="float32"), y: R.Tensor(("s1", 4), dtype="float32")
        ) -> R.Tuple(R.Tensor(("s1___1 + s1", 4), dtype="float32")):
            s1___1 = T.int64(is_size_var=True)
            s1 = T.int64(is_size_var=True)
            R.func_attr(
                {
                    "tir_var_lower_bound": {"s1": 0, "s1___1": 1},
                    "tir_var_upper_bound": {"s1": 63, "s1___1": 64},
                }
            )
            with R.dataflow():
                lv: R.Tensor((s1___1 + s1, 4), dtype="float32") = R.concat((x, y), axis=0)
                gv: R.Tuple(R.Tensor((s1___1 + s1, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    batch = torch.export.Dim("batch", min=1, max=64)
    example_args = (torch.randn(8, 4), torch.randn(7, 4))
    dynamic_shapes = {"x": {0: batch}, "y": {0: batch - 1}}
    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)

    mod = from_exported_program(exported_program)
    # map_free_vars=True: the size-var names differ between the two modules.
    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


def test_dynamic_shape_with_multiplication_constraints():
    """A derived dim (batch * 2) becomes its own size var (s0_2) whose bounds
    are the base Dim's bounds doubled."""

    class ConcatModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.cat([x, y], dim=0)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0_2", 4), dtype="float32")
        ) -> R.Tuple(R.Tensor(("s0 + s0_2", 4), dtype="float32")):
            s0 = T.int64(is_size_var=True)
            s0_2 = T.int64(is_size_var=True)
            R.func_attr(
                {
                    "tir_var_lower_bound": {"s0": 1, "s0_2": 2},
                    "tir_var_upper_bound": {"s0": 64, "s0_2": 128},
                }
            )
            with R.dataflow():
                lv: R.Tensor((s0 + s0_2, 4), dtype="float32") = R.concat((x, y), axis=0)
                gv: R.Tuple(R.Tensor((s0 + s0_2, 4), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    batch = torch.export.Dim("batch", min=1, max=64)
    example_args = (torch.randn(8, 4), torch.randn(16, 4))
    dynamic_shapes = {"x": {0: batch}, "y": {0: batch * 2}}
    exported_program = export(ConcatModel(), args=example_args, dynamic_shapes=dynamic_shapes)

    mod = from_exported_program(exported_program)
    # map_free_vars=True: the size-var names differ between the two modules.
    tvm.ir.assert_structural_equal(mod, Expected, map_free_vars=True)


def test_sym_size_int():
    """aten.sym_size.int on static and dynamic shapes.

    Static case: the size folds to a constant (3.0) in the expected IR.
    Dynamic case: the size stays symbolic and feeds a reshape.
    """

    class SymSizeInt(Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            # TODO(@mshr-h): `torch.ops.aten.sym_size.int(x, self.dim)` would be ideal, but currently
            # the ep frontend is not able to handle it.
            return torch.add(x[0], torch.ops.aten.sym_size.int(x, self.dim))

    @I.ir_module
    class Expected1:
        @R.function
        def main(
            x: R.Tensor((1, 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((3, 4), dtype="float32") = R.take(
                    x, R.const(0, "int64"), axis=0, mode="fast"
                )
                # sym_size folded: dim 1 (== dim -2) of a (1, 3, 4) input is 3.
                lv1: R.Tensor((3, 4), dtype="float32") = R.add(lv, R.const(3.0, "float32"))
                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args_1 = (torch.randn(1, 3, 4),)
    # dim=1 and dim=-2 refer to the same axis, so both match Expected1.
    verify_model(SymSizeInt(dim=1), example_args_1, {}, Expected1)
    verify_model(SymSizeInt(dim=-2), example_args_1, {}, Expected1)

    class SymSizeIntDynamic(Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            shape_dim = torch.ops.aten.sym_size.int(x, self.dim)
            return x.reshape(shape_dim, -1)

    @I.ir_module
    class Expected2:
        @R.function
        def main(
            x: R.Tensor(("s0", 3, 4), dtype="float32")
        ) -> R.Tuple(R.Tensor(("s0", 12), dtype="float32")):
            s0 = T.int64(is_size_var=True)
            with R.dataflow():
                lv: R.Tensor((s0, 12), dtype="float32") = R.reshape(x, R.shape([s0, 12]))
                gv: R.Tuple(R.Tensor((s0, 12), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args_2 = (torch.randn(2, 3, 4),)
    dynamic_shapes = {"x": {0: torch.export.Dim("dim")}}
    verify_model(
        SymSizeIntDynamic(dim=0), example_args_2, {}, Expected2, dynamic_shapes=dynamic_shapes
    )


def test_exponential():
    """In-place Tensor.exponential_().

    NOTE(review): the expected IR uses R.zeros_like rather than a random op —
    presumably the frontend only models the shape/dtype of the random fill;
    confirm against the converter. The aliased result appears twice in the
    output tuple (the in-place op returns its input).
    """

    class Exponential(Module):
        def forward(self, x):
            return x.exponential_()

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((4, 8), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 8), dtype="float32"), R.Tensor((4, 8), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 8), dtype="float32") = R.zeros_like(x, dtype="void")
                gv: R.Tuple(
                    R.Tensor((4, 8), dtype="float32"), R.Tensor((4, 8), dtype="float32")
                ) = (lv, lv)
                R.output(gv)
            return gv

    example_args = (torch.randn(4, 8, dtype=torch.float32),)
    verify_model(Exponential(), example_args, {}, Expected)


def test_max_dim():
    """torch.max(x, dim=...) lowers to R.topk with k=1 returning (values,
    indices); keepdim=False additionally squeezes the reduced axis."""

    class MaxDim1(Module):
        def forward(self, x):
            return torch.max(x, dim=1)

    class MaxDim2(Module):
        def forward(self, x):
            return torch.max(x, dim=1, keepdim=True)

    # keepdim=False: topk result is squeezed along the reduced axis.
    @I.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((4, 8, 16), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 16), dtype="float32"), R.Tensor((4, 16), dtype="int64")):
            with R.dataflow():
                lv: R.Tuple(
                    R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64")
                ) = R.topk(x, k=1, axis=1, ret_type="both", largest=True, dtype="int64")
                lv1: R.Tensor((4, 1, 16), dtype="float32") = lv[0]
                lv2: R.Tensor((4, 16), dtype="float32") = R.squeeze(lv1, axis=[1])
                lv3: R.Tensor((4, 1, 16), dtype="int64") = lv[1]
                lv4: R.Tensor((4, 16), dtype="int64") = R.squeeze(lv3, axis=[1])
                lv5: R.Tuple(
                    R.Tensor((4, 16), dtype="float32"), R.Tensor((4, 16), dtype="int64")
                ) = (lv2, lv4)
                lv6: R.Tensor((4, 16), dtype="float32") = lv5[0]
                lv7: R.Tensor((4, 16), dtype="int64") = lv5[1]
                gv: R.Tuple(
                    R.Tensor((4, 16), dtype="float32"), R.Tensor((4, 16), dtype="int64")
                ) = (lv6, lv7)
                R.output(gv)
            return gv

    # keepdim=True: the size-1 axis is preserved, so no squeeze is emitted.
    @I.ir_module
    class expected2:
        @R.function
        def main(
            x: R.Tensor((4, 8, 16), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64")):
            with R.dataflow():
                lv: R.Tuple(
                    R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64")
                ) = R.topk(x, k=1, axis=1, ret_type="both", largest=True, dtype="int64")
                lv1: R.Tensor((4, 1, 16), dtype="float32") = lv[0]
                lv2: R.Tensor((4, 1, 16), dtype="int64") = lv[1]
                lv3: R.Tuple(
                    R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64")
                ) = (lv1, lv2)
                lv4: R.Tensor((4, 1, 16), dtype="float32") = lv3[0]
                lv5: R.Tensor((4, 1, 16), dtype="int64") = lv3[1]
                gv: R.Tuple(
                    R.Tensor((4, 1, 16), dtype="float32"), R.Tensor((4, 1, 16), dtype="int64")
                ) = (lv4, lv5)
                R.output(gv)
            return gv

    example_args = (torch.randn(4, 8, 16, dtype=torch.float32),)
    verify_model(MaxDim1(), example_args, {}, expected1)
    verify_model(MaxDim2(), example_args, {}, expected2)


def test_alias():
    """aten.alias is a no-op: the input tensor is returned directly, with no
    intermediate binding in the expected IR."""

    class Alias(Module):
        def forward(self, x):
            return torch.ops.aten.alias(x)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((4, 8), dtype="float32")
        ) -> R.Tuple(R.Tensor((4, 8), dtype="float32")):
            with R.dataflow():
                gv: R.Tuple(R.Tensor((4, 8), dtype="float32")) = (x,)
                R.output(gv)
            return gv

    example_args = (torch.randn(4, 8, dtype=torch.float32),)
    verify_model(Alias(), example_args, {}, Expected)


def test_scatter_value():
    """Tensor.scatter with a scalar value: the scalar is broadcast to the
    index shape, then applied via R.scatter_elements."""

    class ScatterValue(Module):
        def forward(self, x, index):
            return x.scatter(1, index, 0.5)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((4, 8), dtype="float32"),
            index: R.Tensor((4, 2), dtype="int64"),
        ) -> R.Tuple(R.Tensor((4, 8), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 2), dtype="float32") = R.broadcast_to(
                    R.const(0.5, "float32"), R.shape([4, 2])
                )
                lv1: R.Tensor((4, 8), dtype="float32") = R.scatter_elements(x, index, lv, axis=1)
                gv: R.Tuple(R.Tensor((4, 8), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(4, 8, dtype=torch.float32),
        # Indices must stay within the scatter axis extent (dim 1 has size 8).
        torch.randint(0, 8, (4, 2), dtype=torch.int64),
    )
    verify_model(ScatterValue(), example_args, {}, Expected)


def test_grid_sample():
    """F.grid_sample maps to R.image.grid_sample with the mode, padding_mode,
    and align_corners arguments carried through (NCHW layout)."""

    class GridSample(Module):
        def forward(self, input, grid):
            return torch.nn.functional.grid_sample(
                input, grid, mode="bilinear", padding_mode="zeros", align_corners=True
            )

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 4, 4), dtype="float32"),
            grid: R.Tensor((1, 2, 2, 2), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 2, 2), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 2, 2), dtype="float32") = R.image.grid_sample(
                    input_1,
                    grid,
                    method="bilinear",
                    layout="NCHW",
                    padding_mode="zeros",
                    align_corners=True,
                )
                gv: R.Tuple(R.Tensor((1, 3, 2, 2), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (
        torch.randn(1, 3, 4, 4, dtype=torch.float32),
        torch.randn(1, 2, 2, 2, dtype=torch.float32),
    )
    verify_model(GridSample(), example_args, {}, expected)


if __name__ == "__main__":
    # Run the tests in this file when invoked directly as a script.
    tvm.testing.main()
